mirror of https://github.com/hpcaitech/ColossalAI
commit 7904baf6e1 (parent f03bcb359b)
@@ -48,9 +48,13 @@ class PipelineSchedule(BaseSchedule):
         # Pipeline schedule just puts data in memory
         self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
+        if isinstance(self.batch_data, torch.Tensor):
+            batch_size = self.batch_data.size(0)
+        else:
+            batch_size = next(iter(self.batch_data.values())).size(0)
+        assert batch_size % self.num_microbatches == 0, \
             "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        self.microbatch_size = batch_size // self.num_microbatches

     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
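Note on the hunk above: the microbatch size is now derived from the loaded batch itself, so the schedule works whether the batch is a single tensor or a dict of tensors sharing one batch dimension. A minimal standalone sketch of the same idea (not the actual PipelineSchedule; the helper names are invented):

# Minimal sketch: derive the batch size from the loaded batch itself, whether it
# is a single tensor or a dict of tensors, then slice equal microbatches.
import torch


def infer_batch_size(batch_data):
    # A plain tensor: the batch dimension is dim 0.
    if isinstance(batch_data, torch.Tensor):
        return batch_data.size(0)
    # A dict of tensors (e.g. {'input_ids': ..., 'attention_mask': ...}):
    # every value shares the same batch dimension, so peek at any one of them.
    return next(iter(batch_data.values())).size(0)


def slice_microbatch(batch_data, offset, microbatch_size):
    # Cut out rows [offset, offset + microbatch_size) from tensor or dict input.
    if isinstance(batch_data, torch.Tensor):
        return batch_data[offset:offset + microbatch_size]
    return {k: v[offset:offset + microbatch_size] for k, v in batch_data.items()}


if __name__ == "__main__":
    batch = {"input_ids": torch.randint(0, 100, (8, 16)),
             "attention_mask": torch.ones(8, 16)}
    num_microbatches = 4
    batch_size = infer_batch_size(batch)
    assert batch_size % num_microbatches == 0, \
        "Batch size should be divisible by the number of microbatches"
    microbatch_size = batch_size // num_microbatches
    first = slice_microbatch(batch, 0, microbatch_size)
    print(batch_size, microbatch_size, first["input_ids"].shape)  # 8 2 torch.Size([2, 16])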
@@ -71,6 +71,7 @@ class Linear1D(torch.nn.Module):
 @LAYERS.register_module
 class Classifier1D(ParallelLayer):
     """RowLinear with given weight"""
+
     def __init__(self,
                  in_features: int,
                  num_classes: int,
@@ -127,8 +128,8 @@ class Classifier1D(ParallelLayer):

         output_parallel = F.linear(input_, self.weight)
         output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
-
-        output = output + self.bias
+        if self.bias is not None:
+            output = output + self.bias
         return output

@@ -152,6 +153,7 @@ class Linear1D_Col(ParallelLayer):
         which is :math:`Y_i = XA_i`, defaults to False
     :type gather_output: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -233,6 +235,7 @@ class Linear1D_Row(ParallelLayer):
     :param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False
     :type parallel_input: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -302,6 +305,7 @@ class Linear1D_Row(ParallelLayer):
 class MixedFusedLayerNorm1D(torch.nn.Module):
     """ Experimental
     """
+
     def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm1D, self).__init__()

@@ -121,9 +121,10 @@ class classifier_2d(torch.autograd.Function):
         B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
         B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
         B_grad = B_grad.reshape(ctx.B_shape)
-
-        bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
-        bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
+        bias_grad = None
+        if ctx.use_bias:
+            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
+            bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)

         return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
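Note on the hunk above: the bias gradient is now computed and all-reduced only when the forward pass actually used a bias, and is returned as None otherwise so autograd skips it. A rough single-device analogue of that pattern, with the 2D tensor-parallel collectives (reduce_scatter, all_reduce) deliberately left out:

# Hypothetical single-device analogue: a custom autograd Function that only
# materializes a bias gradient when a bias was actually used.
import torch


class LinearWithOptionalBias(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias):
        ctx.save_for_backward(x, weight)
        ctx.use_bias = bias is not None  # remembered for backward, like ctx.use_bias above
        out = x.matmul(weight.t())
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out.matmul(weight)
        grad_w = grad_out.reshape(-1, grad_out.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1]))
        grad_b = None
        if ctx.use_bias:
            # Sum over every dimension except the last, mirroring the guarded branch above.
            grad_b = grad_out.sum(dim=tuple(range(grad_out.ndim - 1)))
        return grad_x, grad_w, grad_b


if __name__ == "__main__":
    x = torch.randn(4, 3, requires_grad=True)
    w = torch.randn(5, 3, requires_grad=True)
    y = LinearWithOptionalBias.apply(x, w, None)  # no bias: grad_b stays None
    y.sum().backward()
    print(x.grad.shape, w.grad.shape)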
@@ -174,9 +175,9 @@ class Matmul_AB_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)

         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size
         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size

         opa = [None] * 2
         opb = [None] * 2
@@ -279,9 +280,9 @@ class Matmul_ABT_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)

         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size
         src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size

         opb = [None] * 2
         opr = [None] * 2
@@ -393,9 +394,9 @@ class Matmul_ATB_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)

         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size
         src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
             pipeline_parallel_rank * tensor_parallel_size

         opa = [None] * 2
         opr = [None] * 2
@@ -38,3 +38,9 @@ class PipelineSharedModuleWrapper:
         for p in module.parameters():
             setattr(p, 'pipeline_shared_module_pg', self.group)
             dist.broadcast(p, src, group=self.group)
+
+    def register_parameter(self, param: nn.Parameter):
+        assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        src = self.ranks_in_group[self.pipeline_ranks[0]]
+        setattr(param, 'pipeline_shared_module_pg', self.group)
+        dist.broadcast(param, src, group=self.group)
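Note on the hunk above: register_parameter lets a single shared parameter (for example a tied embedding weight) be tagged and broadcast from the first rank of the pipeline-shared group, so every member starts from identical values. A rough standalone sketch under the assumption that the process group has already been initialized by the launcher; sync_shared_parameter is an invented helper, not ColossalAI API:

# Broadcast one shared parameter from the first rank of a group so all members
# hold identical values. Run under torchrun; `group` and `ranks_in_group` stand
# in for the wrapper's bookkeeping.
import torch
import torch.distributed as dist
import torch.nn as nn


def sync_shared_parameter(param: nn.Parameter, group, ranks_in_group):
    # The first rank listed for the group acts as the source of truth.
    src = ranks_in_group[0]
    # Tag the parameter so later gradient handling can recognize it as shared.
    setattr(param, 'pipeline_shared_module_pg', group)
    # In-place broadcast: after this call every rank in `group` has src's values.
    dist.broadcast(param.data, src=src, group=group)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # launched with torchrun --nproc_per_node=N
    ranks = list(range(dist.get_world_size()))
    group = dist.new_group(ranks)
    embed = nn.Embedding(10, 4)
    sync_shared_parameter(embed.weight, group, ranks)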
@@ -25,6 +25,7 @@ class Metric(ABC):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only: bool):
         # is the metric only read for the full epoch
         self._epoch_only = epoch_only
@@ -82,6 +83,7 @@ class LossMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only):
         super().__init__(epoch_only=epoch_only)
         self.last_step_loss = torch.zeros(1, device=get_current_device())
@@ -132,6 +134,7 @@ class LearningRateMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
     """
+
     def __init__(self, epoch_only: bool, initial_lr: float = 0.):
         super().__init__(epoch_only=epoch_only)
         self.lr = initial_lr
@@ -159,6 +162,7 @@ class AccuracyMetric(Metric):
     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
    """
+
     def __init__(self, epoch_only: bool, accuracy_func: Callable):
         super().__init__(epoch_only=epoch_only)
         self.acc = accuracy_func
@@ -217,6 +221,7 @@ class MetricHook(BaseHook):
     :type trainer: Trainer
     :type priority: int
     """
+
     def __init__(
         self,
         priority: int,
@@ -238,6 +243,7 @@ class LossHook(MetricHook):
     :type trainer: Trainer
     :type priority: int, optional
     """
+
     def __init__(self, priority: int = 0):
         super().__init__(priority)

@@ -278,6 +284,7 @@ class AccuracyHook(MetricHook):
     :type trainer: Trainer
     :type priority: int
     """
+
     def __init__(self, accuracy_func: Callable, priority: int = 0):
         super().__init__(priority)
         self.accuracy_func = accuracy_func
@@ -351,13 +358,17 @@ class ThroughputHook(MetricHook):
             trainer.states['metrics']['test']['Throughput'] = self.metric

     def before_train_epoch(self, trainer):
-        self.metric.reset()
+        if self._is_stage_to_compute:
+            self.metric.reset()

     def after_train_iter(self, trainer, *args):
-        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+        if self._is_stage_to_compute:
+            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
-        self.metric.reset()
+        if self._is_stage_to_compute:
+            self.metric.reset()

     def after_test_iter(self, trainer, *args):
-        self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+        if self._is_stage_to_compute:
+            self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
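Note on the hunk above: every callback is now gated by self._is_stage_to_compute, so ranks that do not own the throughput metric (for example non-final pipeline stages) skip both the reset and the update. A simplified illustration of the guard, using an invented meter and hook rather than the real trainer interface:

# Gate every metric reset/update behind a per-rank flag so ranks that do not
# own the metric skip the work entirely.
import time


class ThroughputMeter:
    def __init__(self):
        self.samples, self.seconds = 0, 0.0

    def reset(self):
        self.samples, self.seconds = 0, 0.0

    def update(self, num_samples, elapsed):
        self.samples += num_samples
        self.seconds += elapsed

    def value(self):
        return self.samples / self.seconds if self.seconds else 0.0


class GuardedThroughputHook:
    def __init__(self, is_stage_to_compute: bool):
        # Mirrors self._is_stage_to_compute in the hunk above.
        self._is_stage_to_compute = is_stage_to_compute
        self.metric = ThroughputMeter() if is_stage_to_compute else None

    def before_train_epoch(self):
        if self._is_stage_to_compute:
            self.metric.reset()

    def after_train_iter(self, batch_size, step_time):
        if self._is_stage_to_compute:
            self.metric.update(batch_size, step_time)


if __name__ == "__main__":
    hook = GuardedThroughputHook(is_stage_to_compute=True)
    hook.before_train_epoch()
    start = time.time()
    time.sleep(0.01)  # stand-in for one training step
    hook.after_train_iter(batch_size=32, step_time=time.time() - start)
    print(f"{hook.metric.value():.1f} samples/s")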
@@ -133,7 +133,7 @@ class GPTBlock(CheckpointModule):
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False):
-        super().__init__()
+        super().__init__(checkpoint=checkpoint)
         self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
         self.attn = GPTSelfAttention(dim=dim,
                                      num_heads=num_heads,
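Note on the hunk above: the checkpoint flag was previously dropped by super().__init__(), so activation checkpointing requested by the caller never took effect in GPTBlock. A guessed-at miniature of the pattern (not the real CheckpointModule) showing why the flag has to be forwarded to the base class:

# A base class that wraps the subclass's forward in torch.utils.checkpoint when
# checkpoint=True. The real ColossalAI CheckpointModule differs; this only
# illustrates the forwarding fix.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyCheckpointModule(nn.Module):
    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self._use_checkpoint = checkpoint

    def __call__(self, *args):
        if self._use_checkpoint and torch.is_grad_enabled():
            # Recompute activations during backward instead of storing them.
            return checkpoint(super().__call__, *args, use_reentrant=False)
        return super().__call__(*args)


class TinyBlock(TinyCheckpointModule):
    def __init__(self, dim: int, checkpoint: bool = False):
        # The fix above: pass the flag through; otherwise it is silently dropped
        # and activation checkpointing never turns on.
        super().__init__(checkpoint=checkpoint)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


if __name__ == "__main__":
    block = TinyBlock(dim=8, checkpoint=True)
    x = torch.randn(2, 8, requires_grad=True)
    block(x).sum().backward()
    print(x.grad.shape)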