mirror of https://github.com/hpcaitech/ColossalAI
Optimize pipeline schedule (#94)
* add pipeline shared module wrapper and update load batch * added model parallel process group for amp and clip grad (#86) * added model parallel process group for amp and clip grad * update amp and clip with model parallel process group * remove pipeline_prev/next group (#88) * micro batch offload * optimize pipeline gpu memory usage * pipeline can receive tensor shape (#93) * optimize pipeline gpu memory usage * fix grad accumulation step counter * rename classes and functions Co-authored-by: Frank Lee <somerlee.9@gmail.com>pull/97/head
parent
e5b9f9a08d
commit
96780e6ee4
|
@ -359,12 +359,7 @@ class FP16Optimizer(Optimizer):
|
||||||
# Update across all model parallel instances.
|
# Update across all model parallel instances.
|
||||||
torch.distributed.all_reduce(self.found_inf,
|
torch.distributed.all_reduce(self.found_inf,
|
||||||
op=torch.distributed.ReduceOp.MAX,
|
op=torch.distributed.ReduceOp.MAX,
|
||||||
group=gpc.get_group(ParallelMode.TENSOR))
|
group=gpc.get_group(ParallelMode.MODEL))
|
||||||
|
|
||||||
if is_using_pp():
|
|
||||||
torch.distributed.all_reduce(self.found_inf,
|
|
||||||
op=torch.distributed.ReduceOp.MAX,
|
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE))
|
|
||||||
|
|
||||||
# Check for nan.
|
# Check for nan.
|
||||||
found_inf_flag = (self.found_inf.item() > 0)
|
found_inf_flag = (self.found_inf.item() > 0)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||||
|
|
||||||
|
|
||||||
class _MultiDeviceReplicator(object):
|
class _MultiDeviceReplicator(object):
|
||||||
|
@ -247,10 +248,14 @@ class GradScaler(object):
|
||||||
device),
|
device),
|
||||||
per_device_inv_scale.get(device))
|
per_device_inv_scale.get(device))
|
||||||
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
|
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
|
||||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||||
for tensor in per_device_found_inf._per_device_tensors.values():
|
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
|
||||||
dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
|
coalesced = _flatten_dense_tensors(vals)
|
||||||
group=gpc.get_group(ParallelMode.TENSOR))
|
dist.all_reduce(coalesced,
|
||||||
|
op=dist.ReduceOp.MAX,
|
||||||
|
group=gpc.get_group(ParallelMode.MODEL))
|
||||||
|
for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
|
||||||
|
buf.copy_(synced)
|
||||||
return per_device_found_inf._per_device_tensors
|
return per_device_found_inf._per_device_tensors
|
||||||
|
|
||||||
def unscale_(self, optimizer):
|
def unscale_(self, optimizer):
|
||||||
|
|
|
@ -112,7 +112,7 @@ def _binary_search(weights, num):
|
||||||
return intervals
|
return intervals
|
||||||
|
|
||||||
|
|
||||||
def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||||
assert num_items % num_chunks == 0, \
|
assert num_items % num_chunks == 0, \
|
||||||
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||||||
|
|
||||||
|
@ -134,11 +134,11 @@ def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
||||||
num_total = pipeline_parallel_size * num_chunks
|
num_total = pipeline_parallel_size * num_chunks
|
||||||
num_items = len(weights)
|
num_items = len(weights)
|
||||||
if num_items <= num_total:
|
if num_items <= num_total:
|
||||||
return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
return partition_uniform(num_items, pipeline_parallel_size, num_chunks)
|
||||||
|
|
||||||
intervals = _binary_search(weights, num_total)
|
intervals = _binary_search(weights, num_total)
|
||||||
|
|
||||||
|
@ -151,7 +151,7 @@ def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def _count_layer_params(layers):
|
def count_layer_params(layers):
|
||||||
"""Count the number of parameters in each layer
|
"""Count the number of parameters in each layer
|
||||||
"""
|
"""
|
||||||
param_counts = [0] * len(layers)
|
param_counts = [0] * len(layers)
|
||||||
|
@ -201,11 +201,11 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
|
||||||
# Make a partition
|
# Make a partition
|
||||||
if method == 'layer':
|
if method == 'layer':
|
||||||
num_layers = len(layers)
|
num_layers = len(layers)
|
||||||
parts = _partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
|
parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
|
||||||
elif method == 'parameter':
|
elif method == 'parameter':
|
||||||
param_counts = _count_layer_params(layers)
|
param_counts = count_layer_params(layers)
|
||||||
# print_rank_0(param_counts)
|
# print_rank_0(param_counts)
|
||||||
parts = _partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
|
parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
raise ValueError("Method should be a pre-set string in [layer, parameter]")
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
|
||||||
"""
|
"""
|
||||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
partitions = _partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
|
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
|
||||||
module_list = []
|
module_list = []
|
||||||
for start, end in partitions[pipeline_rank]:
|
for start, end in partitions[pipeline_rank]:
|
||||||
module_list.append(nn.Sequential(*layers[start:end]))
|
module_list.append(nn.Sequential(*layers[start:end]))
|
||||||
|
|
|
@ -14,7 +14,8 @@ INITIALIZER_MAPPING = {
|
||||||
'2d': 'Initializer_2D',
|
'2d': 'Initializer_2D',
|
||||||
'2.5d': 'Initializer_2p5D',
|
'2.5d': 'Initializer_2p5D',
|
||||||
'3d': 'Initializer_3D',
|
'3d': 'Initializer_3D',
|
||||||
'sequence': 'Initializer_Sequence'
|
'sequence': 'Initializer_Sequence',
|
||||||
|
'model': 'Initializer_Model'
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1D parallel
|
# 1D parallel
|
||||||
|
|
|
@ -394,6 +394,9 @@ class ParallelContext:
|
||||||
# LSG: init data parallel process group for compatibility with other parallel module such as zero
|
# LSG: init data parallel process group for compatibility with other parallel module such as zero
|
||||||
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
|
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
|
||||||
|
|
||||||
|
# LSG: init model parallel process group for compatibility with amp and clip grad
|
||||||
|
pg_init.append(dict(type=INITIALIZER_MAPPING['model']))
|
||||||
|
|
||||||
if self.pipeline_parallel_size > 1:
|
if self.pipeline_parallel_size > 1:
|
||||||
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
|
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
|
||||||
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
|
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
|
||||||
|
|
|
@ -14,10 +14,12 @@ class ParallelMode(Enum):
|
||||||
# common parallel
|
# common parallel
|
||||||
DATA = 'data'
|
DATA = 'data'
|
||||||
|
|
||||||
|
# model parallel - containing tensor and pipeline parallel groups
|
||||||
|
# this is added to facilitate amp and grad clipping in hybrid parallel
|
||||||
|
MODEL = 'model'
|
||||||
|
|
||||||
# pipeline parallel
|
# pipeline parallel
|
||||||
PIPELINE = 'pipe'
|
PIPELINE = 'pipe'
|
||||||
PIPELINE_PREV = 'pipe_prev'
|
|
||||||
PIPELINE_NEXT = 'pipe_next'
|
|
||||||
|
|
||||||
# containing all ranks in tensor parallel
|
# containing all ranks in tensor parallel
|
||||||
TENSOR = 'tensor'
|
TENSOR = 'tensor'
|
||||||
|
|
|
@ -6,10 +6,11 @@ from .initializer_data import Initializer_Data
|
||||||
from .initializer_pipeline import Initializer_Pipeline
|
from .initializer_pipeline import Initializer_Pipeline
|
||||||
from .initializer_sequence import Initializer_Sequence
|
from .initializer_sequence import Initializer_Sequence
|
||||||
from .initializer_tensor import Initializer_Tensor
|
from .initializer_tensor import Initializer_Tensor
|
||||||
|
from .initializer_model import Initializer_Model
|
||||||
from .process_group_initializer import ProcessGroupInitializer
|
from .process_group_initializer import ProcessGroupInitializer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
|
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
|
||||||
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
|
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
|
||||||
'Initializer_1D', 'ProcessGroupInitializer'
|
'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.context import Config
|
||||||
|
from colossalai.registry import DIST_GROUP_INITIALIZER
|
||||||
|
from .process_group_initializer import ProcessGroupInitializer
|
||||||
|
from ..parallel_mode import ParallelMode
|
||||||
|
|
||||||
|
|
||||||
|
@DIST_GROUP_INITIALIZER.register_module
|
||||||
|
class Initializer_Model(ProcessGroupInitializer):
|
||||||
|
'''A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel groups).
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
|
||||||
|
self.num_group = self.world_size // self.model_parallel_size
|
||||||
|
|
||||||
|
def init_dist_group(self):
|
||||||
|
'''Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
|
||||||
|
|
||||||
|
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
|
||||||
|
:rtype: tuple
|
||||||
|
'''
|
||||||
|
local_rank = None
|
||||||
|
ranks_in_group = None
|
||||||
|
process_group = None
|
||||||
|
group_world_size = None
|
||||||
|
mode = ParallelMode.MODEL
|
||||||
|
|
||||||
|
for i in range(self.num_group):
|
||||||
|
ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
|
||||||
|
group = dist.new_group(ranks)
|
||||||
|
|
||||||
|
if self.rank in ranks:
|
||||||
|
local_rank = ranks.index(self.rank)
|
||||||
|
group_world_size = len(ranks)
|
||||||
|
process_group = group
|
||||||
|
ranks_in_group = ranks
|
||||||
|
return local_rank, group_world_size, process_group, ranks_in_group, mode
|
|
@ -36,28 +36,4 @@ class Initializer_Pipeline(ProcessGroupInitializer):
|
||||||
process_group, ranks_in_group,
|
process_group, ranks_in_group,
|
||||||
ParallelMode.PIPELINE)))
|
ParallelMode.PIPELINE)))
|
||||||
|
|
||||||
for k in range(pipe_group_size):
|
|
||||||
first = pipe_ranks[k]
|
|
||||||
second = pipe_ranks[(k + 1) % pipe_group_size]
|
|
||||||
ranks = [first, second]
|
|
||||||
group = dist.new_group(ranks)
|
|
||||||
if self.rank == first:
|
|
||||||
local_rank = 0
|
|
||||||
group_world_size = 2
|
|
||||||
process_group = group
|
|
||||||
ranks_in_group = ranks
|
|
||||||
dist_settings.append(
|
|
||||||
tuple((local_rank, group_world_size,
|
|
||||||
process_group, ranks_in_group,
|
|
||||||
ParallelMode.PIPELINE_NEXT)))
|
|
||||||
elif self.rank == second:
|
|
||||||
local_rank = 1
|
|
||||||
group_world_size = 2
|
|
||||||
process_group = group
|
|
||||||
ranks_in_group = ranks
|
|
||||||
dist_settings.append(
|
|
||||||
tuple((local_rank, group_world_size,
|
|
||||||
process_group, ranks_in_group,
|
|
||||||
ParallelMode.PIPELINE_PREV)))
|
|
||||||
|
|
||||||
return dist_settings
|
return dist_settings
|
||||||
|
|
|
@ -2,15 +2,12 @@
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.builder import build_gradient_handler
|
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import is_using_ddp, is_using_pp
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from ._base_gradient_handler import BaseGradientHandler
|
from ._base_gradient_handler import BaseGradientHandler
|
||||||
from ._data_parallel_gradient_handler import DataParallelGradientHandler
|
from ._data_parallel_gradient_handler import DataParallelGradientHandler
|
||||||
from ._zero_gradient_handler import ZeROGradientHandler
|
from ._zero_gradient_handler import ZeROGradientHandler
|
||||||
|
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
|
||||||
|
|
||||||
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler']
|
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
|
||||||
|
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler']
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||||
|
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.registry import GRADIENT_HANDLER
|
||||||
|
from ._base_gradient_handler import BaseGradientHandler
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
@GRADIENT_HANDLER.register_module
|
||||||
|
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
|
||||||
|
"""A helper class to handle all-reduce operations in sub parallel groups.
|
||||||
|
A all-reduce collective communication will be operated in
|
||||||
|
:func:`handle_gradient` among all sub pipeline parallel groups.
|
||||||
|
For better performance, it bucketizes the gradients of all parameters that are
|
||||||
|
the same type to improve the efficiency of communication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def handle_gradient(self):
|
||||||
|
"""A method running a all-reduce operation in sub pipeline parallel groups.
|
||||||
|
"""
|
||||||
|
if gpc.pipeline_parallel_size > 1:
|
||||||
|
# bucketize and all-reduce
|
||||||
|
buckets = defaultdict(lambda: defaultdict(list))
|
||||||
|
# Pack the buckets.
|
||||||
|
for param in self._model.parameters():
|
||||||
|
group = getattr(param, 'pipeline_shared_module_pg', None)
|
||||||
|
if param.requires_grad and param.grad is not None and group is not None:
|
||||||
|
tp = param.data.type()
|
||||||
|
buckets[group][tp].append(param)
|
||||||
|
|
||||||
|
# For each bucket, all-reduce and copy all-reduced grads.
|
||||||
|
for group, group_buckets in buckets.items():
|
||||||
|
for tp, bucket in group_buckets.items():
|
||||||
|
grads = [param.grad.data for param in bucket]
|
||||||
|
coalesced = _flatten_dense_tensors(grads)
|
||||||
|
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
|
||||||
|
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||||
|
buf.copy_(synced)
|
|
@ -5,8 +5,7 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch import Tensor
|
from typing import Iterable, Callable
|
||||||
from typing import Iterable, Union, List, Callable
|
|
||||||
from .._base_engine import Engine
|
from .._base_engine import Engine
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
@ -32,18 +31,17 @@ class BaseSchedule(ABC):
|
||||||
return element
|
return element
|
||||||
|
|
||||||
def _move_to_device(self, data):
|
def _move_to_device(self, data):
|
||||||
if isinstance(data, (tuple, list)):
|
if isinstance(data, dict):
|
||||||
data = tuple([self._move_tensor(d) for d in data])
|
data = {k: self._move_tensor(v) for k, v in data.items()}
|
||||||
elif torch.is_tensor(data):
|
else:
|
||||||
data = data.to(get_current_device()).detach()
|
data = self._move_tensor(data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _to_list(self, data):
|
@staticmethod
|
||||||
if torch.is_tensor(data):
|
def _check_sanity(data, tag):
|
||||||
return [data]
|
assert isinstance(data, (torch.Tensor, dict)), f'{tag} must be torch.Tensor or dict'
|
||||||
return data
|
|
||||||
|
|
||||||
def load_batch(self, data_iter):
|
def load_batch(self, data_iter, to_gpu=True):
|
||||||
"""Loads a batch from data iterator. It returns the data and labels which are
|
"""Loads a batch from data iterator. It returns the data and labels which are
|
||||||
already in the same GPU as where the model's.
|
already in the same GPU as where the model's.
|
||||||
|
|
||||||
|
@ -58,13 +56,17 @@ class BaseSchedule(ABC):
|
||||||
data, label = self.batch_data_process_func(batch_data)
|
data, label = self.batch_data_process_func(batch_data)
|
||||||
else:
|
else:
|
||||||
data, label = batch_data
|
data, label = batch_data
|
||||||
|
self._check_sanity(data, 'data')
|
||||||
if isinstance(label, (tuple, list)):
|
self._check_sanity(label, 'label')
|
||||||
self.batch_size = label[0].size(0)
|
if isinstance(data, torch.Tensor):
|
||||||
|
self.batch_size = data.size(0)
|
||||||
else:
|
else:
|
||||||
self.batch_size = label.size(0)
|
self.batch_size = next(iter(data.values())).size(0)
|
||||||
data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
|
data, label = split_batch(data), split_batch(label)
|
||||||
|
if to_gpu:
|
||||||
return self._move_to_device(data), self._move_to_device(label)
|
return self._move_to_device(data), self._move_to_device(label)
|
||||||
|
return data, label
|
||||||
|
|
||||||
|
|
||||||
def pre_processing(self, engine: Engine):
|
def pre_processing(self, engine: Engine):
|
||||||
"""To perform actions before running the schedule.
|
"""To perform actions before running the schedule.
|
||||||
|
@ -76,7 +78,8 @@ class BaseSchedule(ABC):
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
data_iter: Iterable,
|
data_iter: Iterable,
|
||||||
forward_only: bool,
|
forward_only: bool,
|
||||||
return_loss: bool = True
|
return_loss: bool = True,
|
||||||
|
return_output_label: bool = True
|
||||||
):
|
):
|
||||||
"""The process function over a batch of dataset for training or evaluation.
|
"""The process function over a batch of dataset for training or evaluation.
|
||||||
|
|
||||||
|
@ -85,5 +88,24 @@ class BaseSchedule(ABC):
|
||||||
:param labels: ground truth
|
:param labels: ground truth
|
||||||
:param forward_only: If True, the process won't include backward
|
:param forward_only: If True, the process won't include backward
|
||||||
:param return_loss: If False, the loss won't be returned
|
:param return_loss: If False, the loss won't be returned
|
||||||
|
:param return_output_label: If False, the output and label won't be returned
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _call_engine(engine, inputs):
|
||||||
|
if isinstance(inputs, torch.Tensor):
|
||||||
|
return engine(inputs)
|
||||||
|
else:
|
||||||
|
return engine(**inputs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _call_engine_criterion(engine, outputs, labels):
|
||||||
|
assert isinstance(outputs, (torch.Tensor, list, tuple)
|
||||||
|
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
outputs = (outputs, )
|
||||||
|
if isinstance(labels, torch.Tensor):
|
||||||
|
return engine.criterion(*outputs, labels)
|
||||||
|
else:
|
||||||
|
return engine.criterion(*outputs, **labels)
|
||||||
|
|
|
@ -5,9 +5,7 @@ from typing import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from colossalai.engine import Engine
|
from colossalai.engine import Engine
|
||||||
from torch.optim import Optimizer
|
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
from colossalai.utils import conditional_context
|
from colossalai.utils import conditional_context
|
||||||
|
|
||||||
|
@ -27,17 +25,20 @@ class NonPipelineSchedule(BaseSchedule):
|
||||||
engine: Engine,
|
engine: Engine,
|
||||||
data_iter: Iterable,
|
data_iter: Iterable,
|
||||||
forward_only: bool = False,
|
forward_only: bool = False,
|
||||||
return_loss: bool = True):
|
return_loss: bool = True,
|
||||||
|
return_output_label: bool = True):
|
||||||
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
||||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||||
:param engine: Model for training and inference
|
:param engine: Model for training and inference
|
||||||
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
|
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
|
||||||
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
|
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
|
||||||
:param return_loss: Loss will be returned if True
|
:param return_loss: Loss will be returned if True
|
||||||
|
:param return_output_label: Output and label will be returned if True
|
||||||
:type engine: Iterator
|
:type engine: Iterator
|
||||||
:type data_iter: Iterator
|
:type data_iter: Iterator
|
||||||
:type forward_only: bool, optional
|
:type forward_only: bool, optional
|
||||||
:type return_loss: bool, optional
|
:type return_loss: bool, optional
|
||||||
|
:type return_output_label: bool, optional
|
||||||
|
|
||||||
:return: (output, label, loss)
|
:return: (output, label, loss)
|
||||||
:rtype: Tuple[:class:`torch.Tensor`]
|
:rtype: Tuple[:class:`torch.Tensor`]
|
||||||
|
@ -48,16 +49,20 @@ class NonPipelineSchedule(BaseSchedule):
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||||
output = engine(*data)
|
output = self._call_engine(engine, data)
|
||||||
if not isinstance(output, (tuple, list)):
|
|
||||||
output = (output,)
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = engine.criterion(*output, *label)
|
loss = self._call_engine_criterion(engine, output, label)
|
||||||
|
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
engine.backward(loss)
|
engine.backward(loss)
|
||||||
|
|
||||||
|
if return_output_label:
|
||||||
if return_loss:
|
if return_loss:
|
||||||
return output, label, loss
|
return output, label, loss
|
||||||
else:
|
else:
|
||||||
return output, None, None
|
return output, label, None
|
||||||
|
else:
|
||||||
|
if return_loss:
|
||||||
|
return None, None, loss
|
||||||
|
else:
|
||||||
|
return None, None, None
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
from typing import Union
|
from typing import List, Tuple, Union, Callable
|
||||||
|
import inspect
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from colossalai.communication import *
|
from colossalai.communication import *
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||||
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
||||||
ZeroRedundancyOptimizer_Level_3)
|
ZeroRedundancyOptimizer_Level_3)
|
||||||
from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
|
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
||||||
from ._base_schedule import BaseSchedule
|
from ._base_schedule import BaseSchedule
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,102 +30,79 @@ class PipelineSchedule(BaseSchedule):
|
||||||
:class:`NonPipelineSchedule`.
|
:class:`NonPipelineSchedule`.
|
||||||
|
|
||||||
:param num_microbatches: The number of microbatches
|
:param num_microbatches: The number of microbatches
|
||||||
:param amp_type: The type of automatic mixed precision
|
|
||||||
:param amp_config: The configuration of automatic mixed procision
|
|
||||||
:param sync_data: If set to `True`, will sync data every batch over pipeline stages
|
|
||||||
:type num_microbatches: int
|
:type num_microbatches: int
|
||||||
:type amp_type: AMP_TYPE
|
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
|
||||||
:type amp_config: dict
|
:type batch_data_process_func: Callable
|
||||||
:type sync_data: bool
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_microbatches,
|
num_microbatches,
|
||||||
sync_data: bool = True):
|
batch_data_process_func: Callable = None,
|
||||||
super().__init__()
|
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
|
||||||
|
super().__init__(batch_data_process_func=batch_data_process_func)
|
||||||
self.num_microbatches = num_microbatches
|
self.num_microbatches = num_microbatches
|
||||||
self.sync_data = sync_data
|
|
||||||
self.dtype = torch.float
|
self.dtype = torch.float
|
||||||
|
self.tensor_shape = tensor_shape
|
||||||
|
|
||||||
def _move_to_device(self, data):
|
|
||||||
if isinstance(data, (
|
|
||||||
tuple,
|
|
||||||
list,
|
|
||||||
)):
|
|
||||||
assert len(data) == 1, "Data tuple's length in pipeline should be 1"
|
|
||||||
data = data[0]
|
|
||||||
assert torch.is_tensor(data), "Data in pipeline should be tensor"
|
|
||||||
data = data.to(get_current_device()).detach()
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _sync_data(self):
|
|
||||||
reqs = []
|
|
||||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
|
||||||
src_rank = gpc.get_global_rank()
|
|
||||||
reqs.append(dist.broadcast(
|
|
||||||
tensor=self.batch_data,
|
|
||||||
src=src_rank,
|
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
|
|
||||||
async_op=True
|
|
||||||
))
|
|
||||||
reqs.append(dist.broadcast(
|
|
||||||
tensor=self.batch_label,
|
|
||||||
src=src_rank,
|
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE_PREV),
|
|
||||||
async_op=True
|
|
||||||
))
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
|
||||||
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
|
||||||
reqs.append(dist.broadcast(
|
|
||||||
tensor=self.batch_data,
|
|
||||||
src=src_rank,
|
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
|
|
||||||
async_op=True
|
|
||||||
))
|
|
||||||
reqs.append(dist.broadcast(
|
|
||||||
tensor=self.batch_label,
|
|
||||||
src=src_rank,
|
|
||||||
group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
|
|
||||||
async_op=True
|
|
||||||
))
|
|
||||||
for req in reqs:
|
|
||||||
req.wait()
|
|
||||||
|
|
||||||
# Pipeline schedule just puts data in memory
|
|
||||||
def load_batch(self, data_iter):
|
def load_batch(self, data_iter):
|
||||||
if data_iter is None:
|
# Pipeline schedule just puts data in memory
|
||||||
raise RuntimeError('Dataloader is not defined.')
|
self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
|
||||||
self.batch_pos = 0
|
self.microbatch_offset = 0
|
||||||
data, label = next(data_iter)
|
assert self.batch_size % self.num_microbatches == 0, \
|
||||||
self.batch_data, self.batch_label = \
|
|
||||||
self._move_to_device(data), self._move_to_device(label)
|
|
||||||
batch_size = self.batch_data.shape[0]
|
|
||||||
assert batch_size % self.num_microbatches == 0, \
|
|
||||||
"Batch size should divided by the number of microbatches"
|
"Batch size should divided by the number of microbatches"
|
||||||
self.microbatch_size = batch_size // self.num_microbatches
|
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||||
if self.sync_data:
|
|
||||||
self._sync_data()
|
|
||||||
|
|
||||||
def _get_data_slice(self, tensor):
|
def _get_data_slice(self, data, offset):
|
||||||
return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
|
if isinstance(data, torch.Tensor):
|
||||||
|
return data[offset: offset + self.microbatch_size]
|
||||||
|
else:
|
||||||
|
return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
|
||||||
|
|
||||||
def load_micro_batch(self):
|
def load_micro_batch(self):
|
||||||
data = self._get_data_slice(self.batch_data)
|
data = self._get_data_slice(self.batch_data, self.microbatch_offset)
|
||||||
label = self._get_data_slice(self.batch_label)
|
label = self._get_data_slice(self.batch_label, self.microbatch_offset)
|
||||||
self.batch_pos += self.microbatch_size
|
self.microbatch_offset += self.microbatch_size
|
||||||
return (data,), (label,)
|
return self._move_to_device(data), self._move_to_device(label)
|
||||||
|
|
||||||
def pre_processing(self, engine):
|
def pre_processing(self, engine):
|
||||||
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
|
||||||
)
|
)
|
||||||
|
model = engine.model
|
||||||
if isinstance(engine.model, NaiveAMPModel):
|
if isinstance(model, NaiveAMPModel):
|
||||||
self.dtype = torch.half
|
self.dtype = torch.half
|
||||||
|
model = model.model
|
||||||
|
sig = inspect.signature(model.forward)
|
||||||
|
for p in sig.parameters.values():
|
||||||
|
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||||
|
|
||||||
def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
|
@staticmethod
|
||||||
|
def _call_engine(model, input_tensor, batch_data):
|
||||||
|
if isinstance(model, NaiveAMPModel):
|
||||||
|
sig = inspect.signature(model.model.forward)
|
||||||
|
else:
|
||||||
|
sig = inspect.signature(model.forward)
|
||||||
|
if isinstance(batch_data, torch.Tensor):
|
||||||
|
if input_tensor is None:
|
||||||
|
return model(batch_data)
|
||||||
|
elif len(sig.parameters) > 1:
|
||||||
|
return model(input_tensor, batch_data)
|
||||||
|
else:
|
||||||
|
return model(input_tensor)
|
||||||
|
else:
|
||||||
|
filter_batch = True
|
||||||
|
for p in sig.parameters.values():
|
||||||
|
if p.kind == inspect.Parameter.VAR_KEYWORD:
|
||||||
|
filter_batch = False
|
||||||
|
if filter_batch:
|
||||||
|
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
|
||||||
|
if input_tensor is None:
|
||||||
|
return model(**batch_data)
|
||||||
|
else:
|
||||||
|
return model(input_tensor, **batch_data)
|
||||||
|
|
||||||
|
def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
|
||||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||||
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
||||||
Returns output tensor. This is a helper function and can be ignored by users.
|
Returns output tensor. This is a helper function and can be ignored by users.
|
||||||
|
@ -140,26 +117,19 @@ class PipelineSchedule(BaseSchedule):
|
||||||
:return: output or the loss value of the current pipeline stage
|
:return: output or the loss value of the current pipeline stage
|
||||||
:rtype: :class:`torch.Tensor`
|
:rtype: :class:`torch.Tensor`
|
||||||
"""
|
"""
|
||||||
|
data, label = self.load_micro_batch()
|
||||||
if input_tensor is None:
|
output_tensor = self._call_engine(engine.model, input_tensor, data)
|
||||||
input_tensor, label = self.load_micro_batch()
|
|
||||||
input_tensor = squeeze(input_tensor)
|
|
||||||
output_tensor = engine(input_tensor)
|
|
||||||
output_tensor = squeeze(output_tensor)
|
output_tensor = squeeze(output_tensor)
|
||||||
|
|
||||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
if return_loss:
|
if return_output_label:
|
||||||
input_tensor, label = self.load_micro_batch()
|
return_tensors.append(tuple((output_tensor, label)))
|
||||||
loss_reduced = engine.criterion(output_tensor, *label) \
|
if accum_loss is not None:
|
||||||
/ self.num_microbatches
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
||||||
|
accum_loss.add_(loss_reduced.detach())
|
||||||
return_tensors.append(
|
|
||||||
tuple((output_tensor, label[0], loss_reduced)))
|
|
||||||
return loss_reduced
|
return loss_reduced
|
||||||
else:
|
else:
|
||||||
return_tensors.append(output_tensor)
|
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
|
@ -203,7 +173,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
engine,
|
engine,
|
||||||
data_iter,
|
data_iter,
|
||||||
forward_only=False,
|
forward_only=False,
|
||||||
return_loss=True):
|
return_loss=True,
|
||||||
|
return_output_label=True):
|
||||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||||
|
|
||||||
|
@ -215,6 +186,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
:type forward_only: bool
|
:type forward_only: bool
|
||||||
:param return_loss: whether returns the loss value. Default is true.
|
:param return_loss: whether returns the loss value. Default is true.
|
||||||
:type return_loss: bool
|
:type return_loss: bool
|
||||||
|
:param return_output_label: If False, the output and label won't be returned
|
||||||
|
:type return_output_label: bool
|
||||||
|
|
||||||
:return: (output, label, loss)
|
:return: (output, label, loss)
|
||||||
:rtype: Tuple[:class:`torch.Tensor`]
|
:rtype: Tuple[:class:`torch.Tensor`]
|
||||||
|
@ -238,11 +211,14 @@ class PipelineSchedule(BaseSchedule):
|
||||||
input_tensors = []
|
input_tensors = []
|
||||||
output_tensors = []
|
output_tensors = []
|
||||||
return_tensors = []
|
return_tensors = []
|
||||||
|
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||||
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
|
else:
|
||||||
|
accum_loss = None
|
||||||
# Used for tensor meta information communication
|
# Used for tensor meta information communication
|
||||||
ft_shape = None
|
ft_shape = self.tensor_shape
|
||||||
bt_shape = None
|
bt_shape = None
|
||||||
fs_checker = True
|
fs_checker = self.tensor_shape is None
|
||||||
|
|
||||||
# Run warmup forward passes.
|
# Run warmup forward passes.
|
||||||
for i in range(num_warmup_microbatches):
|
for i in range(num_warmup_microbatches):
|
||||||
|
@ -251,7 +227,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
|
input_tensor = recv_forward(ft_shape, dtype=self.dtype)
|
||||||
output_tensor = self.forward_step(
|
output_tensor = self.forward_step(
|
||||||
engine, input_tensor, return_tensors,
|
engine, input_tensor, return_tensors,
|
||||||
return_loss=return_loss
|
return_output_label=return_output_label,
|
||||||
|
accum_loss=accum_loss
|
||||||
)
|
)
|
||||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||||
bt_shape = output_tensor.shape
|
bt_shape = output_tensor.shape
|
||||||
|
@ -276,7 +253,8 @@ class PipelineSchedule(BaseSchedule):
|
||||||
|
|
||||||
output_tensor = self.forward_step(
|
output_tensor = self.forward_step(
|
||||||
engine, input_tensor, return_tensors,
|
engine, input_tensor, return_tensors,
|
||||||
return_loss=return_loss
|
return_output_label=return_output_label,
|
||||||
|
accum_loss=accum_loss
|
||||||
)
|
)
|
||||||
if forward_only:
|
if forward_only:
|
||||||
send_forward(output_tensor)
|
send_forward(output_tensor)
|
||||||
|
@ -327,24 +305,37 @@ class PipelineSchedule(BaseSchedule):
|
||||||
send_backward(input_tensor_grad)
|
send_backward(input_tensor_grad)
|
||||||
|
|
||||||
if len(return_tensors) > 0:
|
if len(return_tensors) > 0:
|
||||||
if return_loss:
|
output, label = tuple(map(list, zip(*return_tensors)))
|
||||||
output, label, loss = tuple(map(list, zip(*return_tensors)))
|
|
||||||
return (torch.cat(output, dim=0),
|
return (torch.cat(output, dim=0),
|
||||||
torch.cat(label, dim=0),
|
torch.cat(label, dim=0),
|
||||||
sum(loss))
|
accum_loss)
|
||||||
else:
|
else:
|
||||||
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
return tuple((None, None, accum_loss))
|
||||||
else:
|
|
||||||
return tuple((None, None, None))
|
|
||||||
|
|
||||||
|
|
||||||
class InterleavedPipelineSchedule(PipelineSchedule):
|
class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
|
def __init__(self,
|
||||||
|
num_microbatches,
|
||||||
|
num_model_chunks,
|
||||||
|
batch_data_process_func: Callable = None,
|
||||||
|
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None):
|
||||||
|
"""A helper schedule class for pipeline parallelism running environment.
|
||||||
|
It uses interleaved 1F1B strategy. Other properties are similar as
|
||||||
|
:class:`NonPipelineSchedule`.
|
||||||
|
|
||||||
|
:param num_microbatches: The number of microbatches
|
||||||
|
:type num_microbatches: int
|
||||||
|
:param num_model_chunks: The number of model chunks
|
||||||
|
:type num_model_chunks: int
|
||||||
|
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
|
||||||
|
:type batch_data_process_func: Callable
|
||||||
|
"""
|
||||||
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
|
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
|
||||||
'num_microbatches must be an integer multiple of pipeline parallel world size'
|
'num_microbatches must be an integer multiple of pipeline parallel world size'
|
||||||
super().__init__(num_microbatches, sync_data=sync_data)
|
super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func, tensor_shape=tensor_shape)
|
||||||
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
|
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
|
||||||
gpc.set_virtual_pipeline_parallel_rank(0)
|
gpc.set_virtual_pipeline_parallel_rank(0)
|
||||||
|
self.num_model_chunks = num_model_chunks
|
||||||
|
|
||||||
def pre_processing(self, engine):
|
def pre_processing(self, engine):
|
||||||
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
|
||||||
|
@ -355,32 +346,46 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
if isinstance(engine.model[0], NaiveAMPModel):
|
if isinstance(engine.model[0], NaiveAMPModel):
|
||||||
self.dtype = torch.half
|
self.dtype = torch.half
|
||||||
|
|
||||||
def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
|
for model in engine.model:
|
||||||
|
if isinstance(model, NaiveAMPModel):
|
||||||
|
model = model.model
|
||||||
|
sig = inspect.signature(model.forward)
|
||||||
|
for p in sig.parameters.values():
|
||||||
|
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
|
||||||
|
|
||||||
|
def load_batch(self, data_iter):
|
||||||
|
super().load_batch(data_iter)
|
||||||
|
# overwrite microbatch_offset, since model chunks load the same microbatch, and should tract the offset
|
||||||
|
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
|
||||||
|
|
||||||
|
def load_micro_batch(self, model_chunk_id):
|
||||||
|
data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
|
||||||
|
label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
|
||||||
|
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||||
|
return self._move_to_device(data), self._move_to_device(label)
|
||||||
|
|
||||||
|
def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
|
||||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||||
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
|
||||||
Returns output tensor. This is a helper function and can be ignored by users.
|
Returns output tensor. This is a helper function and can be ignored by users.
|
||||||
"""
|
"""
|
||||||
|
data, label = self.load_micro_batch(model_chunk_id)
|
||||||
if input_tensor is None:
|
output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
|
||||||
input_tensor, label = self.load_micro_batch()
|
|
||||||
input_tensor = squeeze(input_tensor)
|
|
||||||
output_tensor = model(input_tensor)
|
|
||||||
output_tensor = squeeze(output_tensor)
|
output_tensor = squeeze(output_tensor)
|
||||||
|
|
||||||
if gpc.is_pipeline_last_stage():
|
if gpc.is_pipeline_last_stage():
|
||||||
if return_loss:
|
if return_output_label:
|
||||||
input_tensor, label = self.load_micro_batch()
|
return_tensors.append(tuple(output_tensor, label))
|
||||||
loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
|
if accum_loss is not None:
|
||||||
return_tensors.append(
|
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
|
||||||
tuple((output_tensor, label[0], loss_reduced)))
|
accum_loss.add_(loss_reduced.detach())
|
||||||
return loss_reduced
|
return loss_reduced
|
||||||
else:
|
else:
|
||||||
return_tensors.append(output_tensor)
|
|
||||||
return output_tensor
|
return output_tensor
|
||||||
else:
|
else:
|
||||||
return output_tensor
|
return output_tensor
|
||||||
|
|
||||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
|
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||||
communication between pipeline stages as needed.
|
communication between pipeline stages as needed.
|
||||||
|
|
||||||
|
@ -394,11 +399,15 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
return_tensors = []
|
return_tensors = []
|
||||||
if not forward_only:
|
if not forward_only:
|
||||||
output_tensor_grads = [[] for _ in range(len(model))]
|
output_tensor_grads = [[] for _ in range(len(model))]
|
||||||
|
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||||
|
accum_loss = torch.zeros(1, device=get_current_device())
|
||||||
|
else:
|
||||||
|
accum_loss = None
|
||||||
|
|
||||||
# Used for tensor meta information communication
|
# Used for tensor meta information communication
|
||||||
input_tensor_shapes = [None for _ in range(len(model))]
|
input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
|
||||||
output_tensor_shapes = [None for _ in range(len(model))]
|
output_tensor_shapes = [None for _ in range(len(model))]
|
||||||
send_tensor_shape_flags = [True for _ in range(len(model))]
|
send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]
|
||||||
|
|
||||||
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
@ -450,8 +459,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
len(output_tensors[model_chunk_id]):
|
len(output_tensors[model_chunk_id]):
|
||||||
input_tensors[model_chunk_id].append(None)
|
input_tensors[model_chunk_id].append(None)
|
||||||
input_tensor = input_tensors[model_chunk_id][-1]
|
input_tensor = input_tensors[model_chunk_id][-1]
|
||||||
output_tensor = self.forward_step(
|
output_tensor = self.forward_step(engine, model_chunk_id, input_tensor,
|
||||||
engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
|
return_tensors, return_output_label=return_output_label, accum_loss=accum_loss)
|
||||||
output_tensors[model_chunk_id].append(output_tensor)
|
output_tensors[model_chunk_id].append(output_tensor)
|
||||||
|
|
||||||
# if forward-only, no need to save tensors for a backward pass
|
# if forward-only, no need to save tensors for a backward pass
|
||||||
|
@ -633,12 +642,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
||||||
dtype=self.dtype))
|
dtype=self.dtype))
|
||||||
|
|
||||||
if len(return_tensors) > 0:
|
if len(return_tensors) > 0:
|
||||||
if return_loss:
|
output, label = tuple(map(list, zip(*return_tensors)))
|
||||||
output, label, loss = tuple(map(list, zip(*return_tensors)))
|
|
||||||
return (torch.cat(output, dim=0),
|
return (torch.cat(output, dim=0),
|
||||||
torch.cat(label, dim=0),
|
torch.cat(label, dim=0),
|
||||||
sum(loss))
|
accum_loss)
|
||||||
else:
|
else:
|
||||||
return tuple((torch.cat(return_tensors, dim=0), None, None))
|
return tuple((None, None, accum_loss))
|
||||||
else:
|
|
||||||
return tuple((None, None, None))
|
|
||||||
|
|
|
@ -338,6 +338,19 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
||||||
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
|
"Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
|
||||||
"added even though not specified in the configuration",
|
"added even though not specified in the configuration",
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
|
# add pipeline parallel gradient handler, if pipeline shared module is detected
|
||||||
|
for param in model.parameters():
|
||||||
|
if getattr(param, 'pipeline_shared_module_pg', None) is not None:
|
||||||
|
if gradient_handler_cfg is None:
|
||||||
|
gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
|
||||||
|
else:
|
||||||
|
gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
|
||||||
|
if verbose:
|
||||||
|
logger.info(
|
||||||
|
"pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
|
||||||
|
"added even though not specified in the configuration",
|
||||||
|
ranks=[0])
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
if not isinstance(gradient_handler_cfg, list):
|
if not isinstance(gradient_handler_cfg, list):
|
||||||
raise ConfigException(
|
raise ConfigException(
|
||||||
|
|
|
@ -11,8 +11,8 @@ _parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d':
|
||||||
def split_batch(input_) -> Tensor:
|
def split_batch(input_) -> Tensor:
|
||||||
tensor_parallel_mode = get_tensor_parallel_mode()
|
tensor_parallel_mode = get_tensor_parallel_mode()
|
||||||
if tensor_parallel_mode in _parallel_split_batch:
|
if tensor_parallel_mode in _parallel_split_batch:
|
||||||
if isinstance(input_, (tuple, list)):
|
if isinstance(input_, dict):
|
||||||
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
|
return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
|
||||||
else:
|
else:
|
||||||
return _parallel_split_batch[tensor_parallel_mode](input_)
|
return _parallel_split_batch[tensor_parallel_mode](input_)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
from .lambda_wrapper import LambdaWrapper
|
from .lambda_wrapper import LambdaWrapper
|
||||||
|
from .pipeline_wrapper import PipelineSharedModuleWrapper
|
||||||
|
|
||||||
__all__ = ['LambdaWrapper']
|
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineSharedModuleWrapper:
|
||||||
|
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
|
||||||
|
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
|
||||||
|
self.pipeline_ranks = pipeline_ranks
|
||||||
|
self.group = None
|
||||||
|
self.ranks_in_group = None
|
||||||
|
self._init_group()
|
||||||
|
|
||||||
|
def _init_group(self):
|
||||||
|
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||||
|
dp_size = gpc.get_world_size(ParallelMode.DATA)
|
||||||
|
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||||
|
rank = gpc.get_global_rank()
|
||||||
|
num_dp_groups = world_size // dp_size
|
||||||
|
num_pp_stages = num_dp_groups // pp_size
|
||||||
|
for i in range(dp_size):
|
||||||
|
for j in range(num_pp_stages):
|
||||||
|
pipeline_ranks = list(
|
||||||
|
range(i * num_dp_groups + j,
|
||||||
|
(i + 1) * num_dp_groups,
|
||||||
|
num_pp_stages))
|
||||||
|
sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
|
||||||
|
group = dist.new_group(sub_ranks)
|
||||||
|
if rank in sub_ranks:
|
||||||
|
self.group = group
|
||||||
|
self.ranks_in_group = sub_ranks
|
||||||
|
|
||||||
|
def register_module(self, module: nn.Module):
|
||||||
|
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||||
|
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||||
|
for p in module.parameters():
|
||||||
|
setattr(p, 'pipeline_shared_module_pg', self.group)
|
||||||
|
dist.broadcast(p, src, group=self.group)
|
|
@ -155,7 +155,8 @@ class Trainer:
|
||||||
def _train_epoch(self,
|
def _train_epoch(self,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
epoch: int = None,
|
epoch: int = None,
|
||||||
display_progress: bool = False):
|
display_progress: bool = False,
|
||||||
|
return_output_label: bool = True):
|
||||||
# set training state
|
# set training state
|
||||||
self._engine.train()
|
self._engine.train()
|
||||||
data_iter = iter(train_dataloader)
|
data_iter = iter(train_dataloader)
|
||||||
|
@ -175,7 +176,7 @@ class Trainer:
|
||||||
# run 1 training step
|
# run 1 training step
|
||||||
self.engine.zero_grad()
|
self.engine.zero_grad()
|
||||||
logits, label, loss = self.schedule.forward_backward_step(
|
logits, label, loss = self.schedule.forward_backward_step(
|
||||||
self.engine, data_iter, forward_only=False, return_loss=True)
|
self.engine, data_iter, forward_only=False, return_loss=True, return_output_label=return_output_label)
|
||||||
self.engine.step()
|
self.engine.step()
|
||||||
self._call_timer(action='stop', item='Train-step', keep_in_history=True)
|
self._call_timer(action='stop', item='Train-step', keep_in_history=True)
|
||||||
self._call_hooks('after_train_iter', output=(logits, label, loss))
|
self._call_hooks('after_train_iter', output=(logits, label, loss))
|
||||||
|
@ -197,7 +198,8 @@ class Trainer:
|
||||||
def _eval(self,
|
def _eval(self,
|
||||||
test_dataloader: DataLoader,
|
test_dataloader: DataLoader,
|
||||||
epoch: int = None,
|
epoch: int = None,
|
||||||
display_progress: bool = False):
|
display_progress: bool = False,
|
||||||
|
return_output_label: bool = True):
|
||||||
# switch engine status
|
# switch engine status
|
||||||
self._engine.eval()
|
self._engine.eval()
|
||||||
|
|
||||||
|
@ -220,7 +222,7 @@ class Trainer:
|
||||||
self._call_hooks('before_test_iter')
|
self._call_hooks('before_test_iter')
|
||||||
self._call_timer(action='start', item='Test-step')
|
self._call_timer(action='start', item='Test-step')
|
||||||
logits, label, loss = self.schedule.forward_backward_step(
|
logits, label, loss = self.schedule.forward_backward_step(
|
||||||
self.engine, data_iter, forward_only=True, return_loss=True)
|
self.engine, data_iter, forward_only=True, return_loss=True, return_output_label=return_output_label)
|
||||||
self._call_timer(action='stop', item='Test-step', keep_in_history=True)
|
self._call_timer(action='stop', item='Test-step', keep_in_history=True)
|
||||||
self._call_hooks('after_test_iter',
|
self._call_hooks('after_test_iter',
|
||||||
output=(logits, label, loss))
|
output=(logits, label, loss))
|
||||||
|
@ -246,6 +248,7 @@ class Trainer:
|
||||||
test_interval: int = 1,
|
test_interval: int = 1,
|
||||||
hooks: List[BaseHook] = None,
|
hooks: List[BaseHook] = None,
|
||||||
display_progress: bool = False,
|
display_progress: bool = False,
|
||||||
|
return_output_label: bool = True,
|
||||||
):
|
):
|
||||||
"""Trains the model to fit training data.
|
"""Trains the model to fit training data.
|
||||||
|
|
||||||
|
@ -256,6 +259,8 @@ class Trainer:
|
||||||
:param test_interval: Interval of testing
|
:param test_interval: Interval of testing
|
||||||
:param hooks_cfg: A list of hook configuration
|
:param hooks_cfg: A list of hook configuration
|
||||||
:param display_progress: If True, the training progress will be printed
|
:param display_progress: If True, the training progress will be printed
|
||||||
|
:param return_output_label: If True, the output of model and the label will be returned
|
||||||
|
:type return_output_label: bool
|
||||||
:type train_dataloader: DataLoader
|
:type train_dataloader: DataLoader
|
||||||
:type epochs: int
|
:type epochs: int
|
||||||
:type max_steps: int
|
:type max_steps: int
|
||||||
|
@ -307,7 +312,8 @@ class Trainer:
|
||||||
self._train_epoch(
|
self._train_epoch(
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
display_progress=display_progress
|
display_progress=display_progress,
|
||||||
|
return_output_label=return_output_label
|
||||||
)
|
)
|
||||||
|
|
||||||
# start eval
|
# start eval
|
||||||
|
@ -315,6 +321,7 @@ class Trainer:
|
||||||
self._eval(test_dataloader=test_dataloader,
|
self._eval(test_dataloader=test_dataloader,
|
||||||
display_progress=display_progress,
|
display_progress=display_progress,
|
||||||
epoch=epoch,
|
epoch=epoch,
|
||||||
|
return_output_label=return_output_label
|
||||||
)
|
)
|
||||||
|
|
||||||
self._cur_epoch += 1
|
self._cur_epoch += 1
|
||||||
|
@ -331,13 +338,16 @@ class Trainer:
|
||||||
def evaluate(self,
|
def evaluate(self,
|
||||||
test_dataloader: DataLoader,
|
test_dataloader: DataLoader,
|
||||||
hooks: List[BaseHook] = None,
|
hooks: List[BaseHook] = None,
|
||||||
display_progress: bool = False):
|
display_progress: bool = False,
|
||||||
|
return_output_label: bool = True):
|
||||||
"""Evaluates the model with testing data.
|
"""Evaluates the model with testing data.
|
||||||
|
|
||||||
:param test_dataloader: DataLoader in testing
|
:param test_dataloader: DataLoader in testing
|
||||||
:param display_progress: If True, the evaluation progress will be printed
|
:param display_progress: If True, the evaluation progress will be printed
|
||||||
|
:param return_output_label: If True, the output of model and the label will be returned
|
||||||
:type test_dataloader: DataLoader
|
:type test_dataloader: DataLoader
|
||||||
:type display_progress: bool, optional
|
:type display_progress: bool, optional
|
||||||
|
:type return_output_label: bool
|
||||||
"""
|
"""
|
||||||
# set display
|
# set display
|
||||||
display_progress = self._should_display_progress(display_progress)
|
display_progress = self._should_display_progress(display_progress)
|
||||||
|
@ -360,6 +370,7 @@ class Trainer:
|
||||||
# eval
|
# eval
|
||||||
self._eval(test_dataloader=test_dataloader,
|
self._eval(test_dataloader=test_dataloader,
|
||||||
display_progress=display_progress,
|
display_progress=display_progress,
|
||||||
|
return_output_label=return_output_label
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||||
|
|
|
@ -155,22 +155,12 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
||||||
if norm_type == inf:
|
if norm_type == inf:
|
||||||
total_norm = max(p.grad.data.abs().max() for p in params)
|
total_norm = max(p.grad.data.abs().max() for p in params)
|
||||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||||
ops = []
|
|
||||||
# Take max across all model-parallel GPUs.
|
# Take max across all model-parallel GPUs.
|
||||||
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||||
ops.append(dist.all_reduce(total_norm_cuda,
|
dist.all_reduce(total_norm_cuda,
|
||||||
op=dist.ReduceOp.MAX,
|
op=dist.ReduceOp.MAX,
|
||||||
group=gpc.get_group(
|
group=gpc.get_group(ParallelMode.MODEL),
|
||||||
ParallelMode.TENSOR),
|
async_op=False)
|
||||||
async_op=True))
|
|
||||||
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
|
||||||
ops.append(dist.all_reduce(total_norm_cuda,
|
|
||||||
op=dist.ReduceOp.MAX,
|
|
||||||
group=gpc.get_group(
|
|
||||||
ParallelMode.PIPELINE),
|
|
||||||
async_op=True))
|
|
||||||
for req in ops:
|
|
||||||
req.wait()
|
|
||||||
total_norm = total_norm_cuda[0].item()
|
total_norm = total_norm_cuda[0].item()
|
||||||
else:
|
else:
|
||||||
tensor_parallel_grads = []
|
tensor_parallel_grads = []
|
||||||
|
|
|
@ -65,6 +65,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
|
||||||
self.optim.backward(scaled_loss)
|
self.optim.backward(scaled_loss)
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||||
|
self.accumulate_step += 1
|
||||||
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
|
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
|
||||||
|
|
||||||
if no_sync:
|
if no_sync:
|
||||||
|
|
|
@ -26,8 +26,6 @@ follow the steps below to create a new distributed initialization.
|
||||||
GLOBAL = 'global'
|
GLOBAL = 'global'
|
||||||
DATA = 'data'
|
DATA = 'data'
|
||||||
PIPELINE = 'pipe'
|
PIPELINE = 'pipe'
|
||||||
PIPELINE_PREV = 'pipe_prev'
|
|
||||||
PIPELINE_NEXT = 'pipe_next'
|
|
||||||
...
|
...
|
||||||
|
|
||||||
NEW_MODE = 'new_mode' # define your mode here
|
NEW_MODE = 'new_mode' # define your mode here
|
||||||
|
|
|
@ -18,8 +18,6 @@ class ParallelMode(Enum):
|
||||||
GLOBAL = 'global'
|
GLOBAL = 'global'
|
||||||
DATA = 'data'
|
DATA = 'data'
|
||||||
PIPELINE = 'pipe'
|
PIPELINE = 'pipe'
|
||||||
PIPELINE_PREV = 'pipe_prev'
|
|
||||||
PIPELINE_NEXT = 'pipe_next'
|
|
||||||
...
|
...
|
||||||
|
|
||||||
NEW_MODE = 'new_mode' # define your mode here
|
NEW_MODE = 'new_mode' # define your mode here
|
||||||
|
|
|
@ -33,6 +33,12 @@ def check_pipeline_parallel_rank(rank):
|
||||||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
|
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_parallel_rank(rank):
|
||||||
|
for i in range(8):
|
||||||
|
if rank in [i, i+8]:
|
||||||
|
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
def check_tensor_parallel_rank(rank):
|
||||||
if rank in [0, 4, 8, 12]:
|
if rank in [0, 4, 8, 12]:
|
||||||
assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
||||||
|
@ -75,6 +81,7 @@ def init_2d(rank, world_size, backend, port, host):
|
||||||
check_data_parallel_rank(rank)
|
check_data_parallel_rank(rank)
|
||||||
check_2d_parallel_rank(rank)
|
check_2d_parallel_rank(rank)
|
||||||
check_pipeline_parallel_rank(rank)
|
check_pipeline_parallel_rank(rank)
|
||||||
|
check_model_parallel_rank(rank)
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
|
||||||
assert ppr == 1
|
assert ppr == 1
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_parallel_rank(rank):
|
||||||
|
for i in range(16):
|
||||||
|
if rank in [i, i+16]:
|
||||||
|
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
def check_tensor_parallel_rank(rank):
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
|
|
||||||
|
@ -98,6 +104,7 @@ def init_2halfd(rank, world_size, backend, port, host):
|
||||||
check_pipeline_parallel_rank(rank)
|
check_pipeline_parallel_rank(rank)
|
||||||
check_tensor_parallel_rank(rank)
|
check_tensor_parallel_rank(rank)
|
||||||
check_2p5d_parallel_rank(rank)
|
check_2p5d_parallel_rank(rank)
|
||||||
|
check_model_parallel_rank(rank)
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,12 @@ def check_pipeline_parallel_rank(rank):
|
||||||
assert ppr == 1
|
assert ppr == 1
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_parallel_rank(rank):
|
||||||
|
for i in range(16):
|
||||||
|
if rank in [i, i+16]:
|
||||||
|
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
def check_tensor_parallel_rank(rank):
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
|
|
||||||
|
@ -90,6 +96,7 @@ def init_3d(rank, world_size, backend, port, host):
|
||||||
check_3d_parallel_rank(rank)
|
check_3d_parallel_rank(rank)
|
||||||
check_data_parallel_rank(rank)
|
check_data_parallel_rank(rank)
|
||||||
check_pipeline_parallel_rank(rank)
|
check_pipeline_parallel_rank(rank)
|
||||||
|
check_model_parallel_rank(rank)
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ BATCH_SIZE = 16
|
||||||
NUM_EPOCHS = 60
|
NUM_EPOCHS = 60
|
||||||
WARMUP_EPOCHS = 5
|
WARMUP_EPOCHS = 5
|
||||||
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
|
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
|
||||||
fp16=dict(mode=AMP_TYPE.TORCH),
|
fp16=dict(mode=AMP_TYPE.NAIVE),
|
||||||
gradient_accumulation=2)
|
gradient_accumulation=2)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -75,40 +75,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
|
||||||
rank, check_equal(grad, output_grad)))
|
rank, check_equal(grad, output_grad)))
|
||||||
|
|
||||||
|
|
||||||
def check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
def check_comm(size, rank, prev_rank, next_rank, logger):
|
||||||
dtype = torch.float32
|
|
||||||
device = get_current_device()
|
|
||||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
|
||||||
# recv_tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
|
||||||
grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
|
||||||
tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
|
|
||||||
dist.all_reduce(tensor)
|
|
||||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
|
||||||
dist.all_reduce(grad)
|
|
||||||
if rank % 2 == 0:
|
|
||||||
need_meta = True
|
|
||||||
need_meta = send_tensor_meta(tensor, need_meta)
|
|
||||||
logger.info('Rank {} shape sent (need meta: {}).'.format(
|
|
||||||
rank, need_meta))
|
|
||||||
req = dist.broadcast(tensor, src=rank, group=down_group, async_op=True)
|
|
||||||
req.wait()
|
|
||||||
out = tensor.clone()
|
|
||||||
logger.info('Rank {} test op: tensor sent.'.format(rank))
|
|
||||||
else:
|
|
||||||
recv_tensor_shape = recv_tensor_meta(None)
|
|
||||||
logger.info('Rank {} shape received. Correct shape: {}'.format(
|
|
||||||
rank, tensor_shape == recv_tensor_shape))
|
|
||||||
out = torch.empty(recv_tensor_shape, dtype=dtype, device=device)
|
|
||||||
req = dist.broadcast(out, src=prev_rank, group=up_group, async_op=True)
|
|
||||||
req.wait()
|
|
||||||
logger.info('Rank {} test op: received tensor ({})'.format(
|
|
||||||
rank, out.shape))
|
|
||||||
|
|
||||||
logger.info('Rank {} test op. Correct tensor: {}'.format(
|
|
||||||
rank, check_equal(tensor, out)))
|
|
||||||
|
|
||||||
|
|
||||||
def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||||
|
@ -117,7 +84,6 @@ def check_comm(size, rank, prev_rank, next_rank, up_group, down_group, logger):
|
||||||
dist.all_reduce(tensor)
|
dist.all_reduce(tensor)
|
||||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||||
dist.all_reduce(grad)
|
dist.all_reduce(grad)
|
||||||
check_op(size, rank, prev_rank, next_rank, up_group, down_group, logger)
|
|
||||||
check_forward(tensor, rank, logger)
|
check_forward(tensor, rank, logger)
|
||||||
check_backward(grad, rank, logger)
|
check_backward(grad, rank, logger)
|
||||||
check_forward_backward(tensor, grad, rank, logger)
|
check_forward_backward(tensor, grad, rank, logger)
|
||||||
|
@ -135,18 +101,13 @@ def run_check(rank, world_size, port):
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
rank = gpc.get_global_rank()
|
rank = gpc.get_global_rank()
|
||||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||||
up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)
|
|
||||||
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
|
|
||||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||||
down_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_NEXT)
|
|
||||||
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
'Rank {0}: prev rank {1} (up: {2}), next rank {3} (down: {4})'.format(
|
'Rank {0}: prev rank {1}, next rank {2}'.format(
|
||||||
rank, prev_rank, up_ranks, next_rank, down_ranks))
|
rank, prev_rank, next_rank))
|
||||||
logger.info('Distributed environment is initialzied.')
|
logger.info('Distributed environment is initialzied.')
|
||||||
|
|
||||||
check_comm(world_size, rank, prev_rank, next_rank, up_group, down_group,
|
check_comm(world_size, rank, prev_rank, next_rank, logger)
|
||||||
logger)
|
|
||||||
gpc.destroy()
|
gpc.destroy()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue