mirror of https://github.com/hpcaitech/ColossalAI
[example] titans for gpt (#2484)
parent
7c31706227
commit
3a21485ead
|
@ -39,9 +39,15 @@ If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-
For simplicity, the input data is randomly generated here.

## Training

We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism.
The other one uses Pipeline Parallelism only.
In the future, we are going to merge them so that they can be used orthogonally to each other.
We provide two stable solutions.
One utilizes Gemini to implement hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism for a Hugging Face GPT model.
The other one uses [Titans](https://github.com/hpcaitech/Titans), a model zoo for distributed execution maintained by ColossalAI, to implement the hybrid parallel strategies of TP + ZeRO + PP.

We recommend using Gemini to quickly run your model in a distributed manner.
It doesn't require significant changes to the model structure, so you can easily apply it to a new model.
Use Titans as an advanced option to pursue more extreme performance.
Titans includes some typical models, such as ViT and GPT.
However, it requires some effort to get started with a new model structure.

### GeminiDDP/ZeRO + Tensor Parallelism

```bash
@ -56,6 +62,11 @@ The `train_gpt_demo.py` provides three distributed plans, you can choose the pla

- PyTorch DDP
- PyTorch ZeRO

### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism

Titans provides a customized GPT model, which uses distributed operators as building blocks.
In [./titans/README.md](./titans/README.md), we provide hybrid parallelism with ZeRO, TP, and PP.
You can switch parallel strategies using a config file.

## Performance

@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,48 @@
|
# Run GPT With Colossal-AI

## How to Prepare Webtext Dataset

You can download the preprocessed sample dataset for this demo via our [Google Drive sharing link](https://drive.google.com/file/d/1QKI6k-e2gJ7XgS8yIpgPPiMmwiBP_BPE/view?usp=sharing).

You can also skip dataset preparation by passing `--use_dummy_dataset` when running.

## Run this Demo

Use the following commands to install prerequisites.

```bash
# assuming CUDA 11.3
pip install -r requirements.txt
```

Use the following commands to execute training.

```bash
#!/usr/bin/env sh
# if you want to use the real dataset, remove --use_dummy_dataset
# export DATA=/path/to/small-gpt-dataset.json

# run on a single node
colossalai run --nproc_per_node=<num_gpus> train_gpt.py --config configs/<config_file> --from_torch --use_dummy_dataset

# run on multiple nodes
colossalai run --nproc_per_node=<num_gpus> \
               --master_addr <hostname> \
               --master_port <port-number> \
               --hosts <list-of-hostname-separated-by-comma> \
               train_gpt.py \
               --config configs/<config_file> \
               --from_torch \
               --use_dummy_dataset

# run on multiple nodes with slurm
srun python \
     train_gpt.py \
     --config configs/<config_file> \
     --host <master_node> \
     --use_dummy_dataset
```

You can set `<config_file>` to any file in the `configs` folder. To get it running quickly, start with `gpt_small_zero3_pp1d.py` on a single node first. See the explanations in the config file for how to change the parallel settings.
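
For example, on a single node with 4 GPUs (a hypothetical layout; adjust `--nproc_per_node` to match your hardware), the quick-start command could look like this:

```bash
colossalai run --nproc_per_node=4 train_gpt.py --config configs/gpt_small_zero3_pp1d.py --from_torch --use_dummy_dataset
```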
|
|
@ -0,0 +1,31 @@
|
|||
from model import GPT2_small_pipeline_hybrid
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
|
||||
BATCH_SIZE = 8
|
||||
NUM_EPOCHS = 10
|
||||
SEQ_LEN = 1024
|
||||
NUM_MICRO_BATCHES = 4
|
||||
HIDDEN_SIZE = 768
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
|
||||
# if you do not want to use ZeRO, just comment out this dictionary
|
||||
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**16))
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
|
||||
|
||||
# pipeline parallel: modify integer value for the number of pipeline stages
|
||||
# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
|
||||
# for the current model implementation, mode can only be 1D or None
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=2, mode='1d'),
|
||||
)
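# example (hypothetical 8-GPU layout, not part of the original config):
# 2 pipeline stages x 4-way tensor parallelism
# parallel = dict(
#     pipeline=2,
#     tensor=dict(size=4, mode='1d'),
# )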
|
|
@ -0,0 +1,31 @@
|
|||
from model import GPT3_pipeline_hybrid
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
|
||||
BATCH_SIZE = 192
|
||||
NUM_EPOCHS = 60
|
||||
SEQ_LEN = 2048
|
||||
NUM_MICRO_BATCHES = 192
|
||||
HIDDEN_SIZE = 12288
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
|
||||
# if you do not want to use ZeRO, just comment out this dictionary
|
||||
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**16))
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
lr=0.00015,
|
||||
weight_decay=1e-2,
|
||||
)
|
||||
|
||||
model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1)
|
||||
|
||||
# pipeline parallel: modify integer value for the number of pipeline stages
|
||||
# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
|
||||
# for the current model implementation, mode can only be 1D or None
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None
|
||||
)
|
|
@ -0,0 +1,3 @@
|
|||
from .embed import vocab_parallel_cross_entropy
|
||||
from .gpt1d import *
|
||||
from .pipeline_gpt1d import *
|
|
@ -0,0 +1,599 @@
|
|||
import torch
|
||||
import torch.nn.init as init
|
||||
from torch import Tensor
|
||||
from torch import distributed as dist
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.base_layer import ParallelLayer
|
||||
from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
|
||||
from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.registry import LAYERS, LOSSES, MODELS
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
max_sequence_length,
|
||||
embedding_dropout_prob,
|
||||
num_tokentypes=0,
|
||||
dtype=torch.float):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_tokentypes = num_tokentypes
|
||||
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
|
||||
self._word_embeddings_key = 'word_embeddings'
|
||||
|
||||
# Position embedding (serial).
|
||||
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
|
||||
self._position_embeddings_key = 'position_embeddings'
|
||||
# Initialize the position embeddings.
|
||||
# self.init_method(self.position_embeddings.weight)
|
||||
|
||||
# Token type embedding.
|
||||
# Add this as an optional field that can be added through
|
||||
# method call so we can load a pretrain model without
|
||||
# token types and add them as needed.
|
||||
self._tokentype_embeddings_key = 'tokentype_embeddings'
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
|
||||
# Initialize the token-type embeddings.
|
||||
# self.init_method(self.tokentype_embeddings.weight)
|
||||
else:
|
||||
self.tokentype_embeddings = None
|
||||
|
||||
# Embeddings dropout
|
||||
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
|
||||
|
||||
def zero_parameters(self):
|
||||
"""Zero out all parameters in embedding."""
|
||||
self.word_embeddings.weight.data.fill_(0)
|
||||
self.word_embeddings.weight.shared = True
|
||||
self.position_embeddings.weight.data.fill_(0)
|
||||
self.position_embeddings.weight.shared = True
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings.weight.data.fill_(0)
|
||||
self.tokentype_embeddings.weight.shared = True
|
||||
|
||||
def add_tokentype_embeddings(self, num_tokentypes):
|
||||
"""Add token-type embedding. This function is provided so we can add
|
||||
token-type embeddings in case the pretrained model does not have it.
|
||||
This allows us to load the model normally and then add this embedding.
|
||||
"""
|
||||
if self.tokentype_embeddings is not None:
|
||||
raise Exception('tokentype embeddings is already initialized')
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
|
||||
self.num_tokentypes = num_tokentypes
|
||||
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
# self.init_method(self.tokentype_embeddings.weight)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, tokentype_ids=None):
|
||||
# Embeddings.
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings
|
||||
|
||||
# Dropout.
|
||||
with seed(ParallelMode.TENSOR):
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
|
||||
"""For easy load."""
|
||||
|
||||
state_dict_ = {}
|
||||
state_dict_[self._word_embeddings_key] \
|
||||
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] \
|
||||
= self.position_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_[self._tokentype_embeddings_key] \
|
||||
= self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
|
||||
return state_dict_
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
"""Customized load."""
|
||||
|
||||
# Word embedding.
|
||||
if self._word_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._word_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'word_embeddings' in key:
|
||||
state_dict_[key.split('word_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Position embedding.
|
||||
if self._position_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._position_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'position_embeddings' in key:
|
||||
state_dict_[key.split('position_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Tokentype embedding.
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_ = {}
|
||||
if self._tokentype_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._tokentype_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
for key in state_dict.keys():
|
||||
if 'tokentype_embeddings' in key:
|
||||
state_dict_[key.split('tokentype_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if len(state_dict_.keys()) > 0:
|
||||
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
else:
|
||||
print('***WARNING*** expected tokentype embeddings in the '
|
||||
'checkpoint but could not find it',
|
||||
flush=True)
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
init_method: method to initialize weights.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
|
||||
super(VocabParallelEmbedding1D, self).__init__()
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
# Set the details for compatibility.
|
||||
self.padding_idx = None
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = gpc.tensor_parallel_size
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = \
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
|
||||
self.tensor_model_parallel_size)
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
|
||||
# Allocate weights and initialize.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
|
||||
init.uniform_(self.weight, -1, 1)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = (input_ < self.vocab_start_index) | \
|
||||
(input_ >= self.vocab_end_index)
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
# Mask the output embedding.
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
return output
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class vocab_parallel_cross_entropy(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, vocab_parallel_logits, target):
|
||||
"""Helper function for the cross entropy."""
|
||||
vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
|
||||
target = target[..., 1:].contiguous()
|
||||
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
|
||||
target.view(-1))
|
||||
|
||||
|
||||
class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_parallel_logits, target):
|
||||
|
||||
# Maximum value along vocab dimension across all GPUs.
|
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
# Get the partition's vocab indices
|
||||
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
|
||||
partition_vocab_size = vocab_parallel_logits.size()[-1]
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
world_size = gpc.tensor_parallel_size
|
||||
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
|
||||
|
||||
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
||||
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
|
||||
masked_target = target.clone() - vocab_start_index
|
||||
masked_target[target_mask] = 0
|
||||
|
||||
# Get predicted-logits = logits[target].
|
||||
# For Simplicity, we convert logits to a 2-D tensor with size
|
||||
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
|
||||
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
|
||||
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
|
||||
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
|
||||
predicted_logits = predicted_logits_1d.view_as(target)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
exp_logits = vocab_parallel_logits
|
||||
torch.exp(vocab_parallel_logits, out=exp_logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
torch.distributed.all_reduce(sum_exp_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
loss = loss.mean()
|
||||
# Store softmax, target-mask and masked-target for backward pass.
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
# Retrieve the tensors saved in the forward pass.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class VocabUtility:
|
||||
"""Split the vocabulary into `world_size` chunks amd return the
|
||||
first and last index of the vocabulary belonging to the `rank`
|
||||
partition. Note that indices are in [first, last)."""
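# Worked example (hypothetical numbers): with a global vocabulary of 50304 tokens and
# world_size=2, rank 0 owns indices [0, 25152) and rank 1 owns [25152, 50304).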
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
|
||||
|
||||
|
||||
class VocabParallelGPTLMHead1D(ParallelLayer):
|
||||
"""
|
||||
Language model head that shares the same parameters with the embedding matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
|
||||
super().__init__()
|
||||
if embed is not None:
|
||||
self.head = embed
|
||||
else:
|
||||
self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = reduce_grad(x, ParallelMode.PARALLEL_1D)
|
||||
x = F.linear(x, self.head.weight)
|
||||
return x
|
||||
|
||||
|
||||
###################################
|
||||
|
||||
|
||||
class HiddenParallelEmbedding(torch.nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
max_sequence_length,
|
||||
embedding_dropout_prob,
|
||||
dtype=torch.float,
|
||||
padding_idx: int = 0,
|
||||
num_tokentypes=0,
|
||||
):
|
||||
super(HiddenParallelEmbedding, self).__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_tokentypes = num_tokentypes
|
||||
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
|
||||
self._word_embeddings_key = 'word_embeddings'
|
||||
|
||||
# Position embedding (serial).
|
||||
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
|
||||
self._position_embeddings_key = 'position_embeddings'
|
||||
# Initialize the position embeddings.
|
||||
# self.init_method(self.position_embeddings.weight)
|
||||
|
||||
# Token type embedding.
|
||||
# Add this as an optional field that can be added through
|
||||
# method call so we can load a pretrain model without
|
||||
# token types and add them as needed.
|
||||
self._tokentype_embeddings_key = 'tokentype_embeddings'
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
# self.init_method(self.tokentype_embeddings.weight)
|
||||
else:
|
||||
self.tokentype_embeddings = None
|
||||
|
||||
# Embeddings dropout
|
||||
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
|
||||
|
||||
def zero_parameters(self):
|
||||
"""Zero out all parameters in embedding."""
|
||||
self.word_embeddings.weight.data.fill_(0)
|
||||
self.word_embeddings.weight.shared = True
|
||||
self.position_embeddings.weight.data.fill_(0)
|
||||
self.position_embeddings.weight.shared = True
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings.weight.data.fill_(0)
|
||||
self.tokentype_embeddings.weight.shared = True
|
||||
|
||||
def add_tokentype_embeddings(self, num_tokentypes):
|
||||
"""Add token-type embedding. This function is provided so we can add
|
||||
token-type embeddings in case the pretrained model does not have it.
|
||||
This allows us to load the model normally and then add this embedding.
|
||||
"""
|
||||
if self.tokentype_embeddings is not None:
|
||||
raise Exception('tokentype embeddings is already initialized')
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
|
||||
self.num_tokentypes = num_tokentypes
|
||||
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
# self.init_method(self.tokentype_embeddings.weight)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, tokentype_ids=None):
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings
|
||||
|
||||
# Dropout.
|
||||
with seed(ParallelMode.TENSOR):
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
|
||||
"""For easy load."""
|
||||
|
||||
state_dict_ = {}
|
||||
state_dict_[self._word_embeddings_key] \
|
||||
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] \
|
||||
= self.position_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_[self._tokentype_embeddings_key] \
|
||||
= self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
|
||||
return state_dict_
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
"""Customized load."""
|
||||
|
||||
# Word embedding.
|
||||
if self._word_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._word_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'word_embeddings' in key:
|
||||
state_dict_[key.split('word_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Position embedding.
|
||||
if self._position_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._position_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'position_embeddings' in key:
|
||||
state_dict_[key.split('position_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Tokentype embedding.
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_ = {}
|
||||
if self._tokentype_embeddings_key in state_dict:
|
||||
state_dict_ = state_dict[self._tokentype_embeddings_key]
|
||||
else:
|
||||
# for backward compatibility.
|
||||
for key in state_dict.keys():
|
||||
if 'tokentype_embeddings' in key:
|
||||
state_dict_[key.split('tokentype_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if len(state_dict_.keys()) > 0:
|
||||
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
else:
|
||||
print('***WARNING*** expected tokentype embeddings in the '
|
||||
'checkpoint but could not find it',
|
||||
flush=True)
|
||||
|
||||
|
||||
class HiddenParallelEmbedding1D(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
init_method: method to initialize weights.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
|
||||
super(HiddenParallelEmbedding1D, self).__init__()
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
|
||||
# Set the details for compatibility.
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
|
||||
# Allocate weights and initialize.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
|
||||
init.uniform_(self.weight, -1, 1)
|
||||
|
||||
def forward(self, input_):
|
||||
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class HiddenParallelGPTLMHead1D(ParallelLayer):
|
||||
"""
|
||||
Language model head that shares the same parameters with the embedding matrix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed=None,
|
||||
embed_dim=None,
|
||||
vocab_size=None,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
if embed is not None:
|
||||
self.head = embed
|
||||
self.synced_embed = True
|
||||
else:
|
||||
# self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
|
||||
# (hidden_size/q, vocab_size)
|
||||
self.synced_embed = False
|
||||
self.head = Linear1D_Row(in_features=embed_dim,
|
||||
out_features=vocab_size,
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
parallel_input=False)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if self.synced_embed:
|
||||
x = F.linear(x, self.head.weight)
|
||||
else:
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
|
@ -0,0 +1,349 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn as nn
|
||||
|
||||
from colossalai import kernel
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.nn.layer.base_layer import ParallelLayer
|
||||
from colossalai.nn.layer.utils import ACT2FN, divide
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils.activation_checkpoint import checkpoint
|
||||
|
||||
__all__ = [
|
||||
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
|
||||
]
|
||||
|
||||
|
||||
class GPTMLP1D(ParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
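# (column-parallel: gather_output=False keeps the intermediate activation split across tensor-parallel ranks)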
|
||||
self.dense_1 = Linear1D_Col(
|
||||
self.in_features,
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
dtype=dtype,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_dense_1_add_bias,
|
||||
)
|
||||
|
||||
# Project back to h.
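# (row-parallel: parallel_input=True consumes the split activation and the partial results are reduced across ranks)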
|
||||
self.dense_2 = Linear1D_Row(
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
)
|
||||
|
||||
self.dropout = col_nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
output = self.dense_2(intermediate_output)
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, False, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
class GenericGPTSelfAttention1D(ParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
|
||||
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
|
||||
self.checkpoint = checkpoint
|
||||
self.query_key_value = Linear1D_Col(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear1D_Row(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
)
|
||||
self.dropout = col_nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
|
||||
raise NotImplementedError
|
||||
|
||||
def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)
|
||||
|
||||
attention_scores = attention_scores.type(value_layer.dtype)
|
||||
|
||||
attention_probs = self.attention_dropout(attention_scores)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
|
||||
return checkpoint(self._forward, False, hidden_states, attention_mask)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states, attention_mask)
|
||||
else:
|
||||
return self._forward(hidden_states, attention_mask)
|
||||
|
||||
|
||||
class GPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024):
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
max_positions = max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions),
|
||||
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
||||
|
||||
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# causal mask
|
||||
query_length, key_length = query_layer.size(-2), key_layer.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
|
||||
attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_scores = self.softmax(attention_scores)
|
||||
return attention_scores
|
||||
|
||||
|
||||
class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024):
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings)
|
||||
self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
|
||||
input_in_bf16=False,
|
||||
attn_mask_type=AttnMaskType.causal,
|
||||
scaled_masked_softmax_fusion=True,
|
||||
mask_func=None,
|
||||
softmax_in_fp32=True,
|
||||
scale=math.sqrt(self.attention_head_size))
|
||||
|
||||
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
|
||||
return self.softmax(attention_scores, attention_mask)
|
||||
|
||||
|
||||
class GenericGPTTransformerLayer1D(ParallelLayer):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
attention=None,
|
||||
layer_norm=None):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.dtype = dtype
|
||||
self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.apply_post_layer_norm = apply_post_layer_norm
|
||||
self.attention = attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
checkpoint=False,
|
||||
)
|
||||
|
||||
self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.mlp = GPTMLP1D(
|
||||
in_features=hidden_size,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype,
|
||||
checkpoint=False,
|
||||
)
|
||||
|
||||
def _forward(self, hidden_states, attention_mask) -> Tensor:
|
||||
if not self.apply_post_layer_norm:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
if self.apply_post_layer_norm:
|
||||
residual = hidden_states
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
hidden_states = residual + attention_output
|
||||
|
||||
if not self.apply_post_layer_norm:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
if self.apply_post_layer_norm:
|
||||
residual = hidden_states
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
output = (hidden_states, attention_mask)
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
if self.checkpoint:
|
||||
return checkpoint(self._forward, False, hidden_states, attention_mask)
|
||||
else:
|
||||
return self._forward(hidden_states, attention_mask)
|
||||
|
||||
|
||||
class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False):
|
||||
attention = GPTSelfAttention1D
|
||||
layer_norm = nn.LayerNorm
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm)
|
||||
|
||||
|
||||
class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False):
|
||||
attention = FusedGPTSelfAttention1D
|
||||
layer_norm = kernel.LayerNorm
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm)
|
|
@ -0,0 +1,322 @@
|
|||
import inspect
|
||||
|
||||
# import model_zoo.gpt.gpt as col_gpt
|
||||
import titans.model.gpt.gpt as col_gpt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai import kernel
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
|
||||
from colossalai.pipeline.utils import partition_uniform
|
||||
|
||||
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
|
||||
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
|
||||
|
||||
__all__ = [
|
||||
'GPT2_small_pipeline_1D',
|
||||
'GPT2_exlarge_pipeline_1D',
|
||||
'GPT3_pipeline_1D',
|
||||
'GPT2_exlarge_pipeline_hybrid',
|
||||
'GPT2_small_pipeline_hybrid',
|
||||
'GPT3_pipeline_hybrid',
|
||||
]
|
||||
|
||||
|
||||


class GenericPipelineGPT(nn.Module):

    def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
        super().__init__()
        self.embedding = embedding
        self.blocks = blocks
        self.norm = norm
        self.head = head
        assert blocks is not None
        if norm is not None or head is not None:
            assert norm is not None and head is not None

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        # Turn the 0/1 padding mask into a broadcastable additive mask:
        # 0 for visible positions, a large negative value for masked ones.
        batch_size = hidden_states.shape[0]
        attention_mask = attention_mask.view(batch_size, -1)
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        attention_mask = (1.0 - attention_mask) * -10000.0
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states


class PipelineGPT1D(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            GPTTransformerLayer1D(hidden_size,
                                  num_attention_heads,
                                  act_func=act_func,
                                  mlp_ratio=mlp_ratio,
                                  attention_dropout_prob=attn_drop_rate,
                                  hidden_dropout_prob=drop_rate,
                                  dtype=dtype,
                                  checkpoint=checkpoint,
                                  max_position_embeddings=max_position_embeddings,
                                  layer_norm_epsilon=layer_norm_epsilon,
                                  apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)


class FusedPipelineGPT1D(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        embed_cls = VocabParallelEmbedding
        head_cls = VocabParallelGPTLMHead1D
        if embed_split_hidden:
            embed_cls = HiddenParallelEmbedding
            head_cls = HiddenParallelGPTLMHead1D
        if first:
            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
        blocks = nn.ModuleList([
            FusedGPTTransformerLayer1D(hidden_size,
                                       num_attention_heads,
                                       act_func=act_func,
                                       mlp_ratio=mlp_ratio,
                                       attention_dropout_prob=attn_drop_rate,
                                       hidden_dropout_prob=drop_rate,
                                       dtype=dtype,
                                       checkpoint=checkpoint,
                                       max_position_embeddings=max_position_embeddings,
                                       layer_norm_epsilon=layer_norm_epsilon,
                                       apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
        ])
        if last:
            norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)

    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.embedding is not None:
            hidden_states = self.embedding(input_ids=input_ids)
        # The fused attention kernel consumes the mask directly, so no additive-mask
        # conversion is done here.
        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
        for block in self.blocks:
            hidden_states, attention_mask = block(hidden_states, attention_mask)
        if self.norm is not None:
            hidden_states = self.head(self.norm(hidden_states))
        return hidden_states


class PipelineGPTHybrid(GenericPipelineGPT):

    def __init__(self,
                 num_layers: int = 12,
                 hidden_size: int = 768,
                 num_attention_heads: int = 12,
                 vocab_size: int = 50304,
                 embed_drop_rate: float = 0.,
                 act_func: str = 'gelu',
                 mlp_ratio: int = 4,
                 attn_drop_rate: float = 0.,
                 drop_rate: float = 0.,
                 dtype: torch.dtype = torch.float,
                 checkpoint: bool = False,
                 max_position_embeddings: int = 1024,
                 layer_norm_epsilon: float = 1e-5,
                 apply_post_layer_norm: bool = False,
                 first: bool = False,
                 last: bool = False,
                 embed_split_hidden=False):
        embedding = None
        norm = None
        head = None
        if first:
            embedding = col_gpt.GPTEmbedding(hidden_size,
                                             vocab_size,
                                             max_position_embeddings,
                                             dropout=embed_drop_rate,
                                             dtype=dtype)
        blocks = nn.ModuleList([
            col_gpt.GPTBlock(hidden_size,
                             num_attention_heads,
                             mlp_ratio=mlp_ratio,
                             attention_dropout=attn_drop_rate,
                             dropout=drop_rate,
                             dtype=dtype,
                             checkpoint=checkpoint,
                             activation=nn.functional.gelu) for _ in range(num_layers)
        ])
        if last:
            norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            # head = col_gpt.GPTLMHead(vocab_size=vocab_size,
            #                          hidden_size=hidden_size,
            #                          dtype=dtype,
            #                          bias=False)
            head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)


def _filter_kwargs(func, kwargs):
    # Keep only the kwargs that `func` actually accepts, so one shared kwargs
    # dict can drive the different pipeline module constructors.
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()

    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pipeline_size = 1
        pipeline_rank = 0
    rank = gpc.get_global_rank()

    if pipeline_size > 1:
        # Share weights between the word embedding on the first stage and the LM head on the last stage.
        wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
    else:
        wrapper = None
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []
    for start, end in parts:
        kwargs['num_layers'] = end - start
        kwargs['first'] = start == 0
        kwargs['last'] = end == num_layers
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)

        if wrapper is not None:
            if start == 0:
                wrapper.register_module(chunk.embedding.word_embeddings)
            elif end == num_layers:
                wrapper.register_module(chunk.head)
        models.append(chunk)
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    numel = 0
    for _, param in model.named_parameters(recurse=True):
        numel += param.numel()
    logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')    # assumes 2 bytes per parameter
    return model


def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
    model = FusedPipelineGPT1D if fused else PipelineGPT1D
    return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)


def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
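For intuition about the builder above: `partition_uniform(num_layers, pipeline_size, num_chunks)` yields, for each pipeline rank, a list of `(start, end)` layer ranges, and only the chunk that starts at layer 0 receives the embedding (`first`) while only the chunk that ends at `num_layers` receives the norm and head (`last`). The snippet below is a rough, self-contained illustration of such an even split, not the library implementation; `naive_uniform_partition` is a made-up helper, and the real `colossalai.pipeline.utils.partition_uniform` may balance remainders differently.

```python
def naive_uniform_partition(num_layers: int, pipeline_size: int):
    """Illustrative even split: one (start, end) chunk of layers per pipeline rank."""
    base, remainder = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        end = start + base + (1 if rank < remainder else 0)
        parts.append([(start, end)])
        start = end
    return parts


if __name__ == '__main__':
    # 48 layers over 4 stages: rank 0 builds layers 0-12 (plus the embedding),
    # rank 3 builds layers 36-48 (plus the norm and LM head).
    print(naive_uniform_partition(48, 4))
```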


def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=768,
               num_attention_heads=12,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=1600,
               num_attention_heads=32,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)


def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
    cfg = dict(hidden_size=12288,
               num_attention_heads=96,
               checkpoint=checkpoint,
               max_position_embeddings=2048,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=1600,
               num_attention_heads=32,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)


def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=768,
               num_attention_heads=12,
               checkpoint=checkpoint,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)


def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=12288,
               num_attention_heads=96,
               checkpoint=checkpoint,
               max_position_embeddings=2048,
               dtype=dtype,
               embed_split_hidden=embed_split_hidden)
    return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
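The pipeline factories above are not called directly; `train_gpt.py` below instantiates whatever `type` the config file supplies via `gpc.config.model.pop('type')(**gpc.config.model)`, and parallel strategies are switched through that config. As a hypothetical sketch only (the shipped `./configs/gpt2_small_zero3_pp1d.py` may differ, and the import paths for the factory and the shard strategy are assumptions), such a config could look like this:

```python
from torch.optim import Adam

from colossalai.zero.shard_utils import TensorShardStrategy    # assumed import path for this ColossalAI version

from model import GPT2_small_pipeline_hybrid    # hypothetical module path for the pipeline factories above

BATCH_SIZE = 8
SEQ_LEN = 1024
NUM_EPOCHS = 10

# 2 pipeline stages x 2-way 1D tensor parallelism (legacy ColossalAI config convention).
parallel = dict(pipeline=2, tensor=dict(size=2, mode='1d'))

# Presence of a `zero` section makes train_gpt.py wrap model construction in ZeroInitContext;
# a real ZeRO config usually carries more fields than shown here.
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy()))

optimizer = dict(type=Adam, lr=1.5e-4, weight_decay=1e-2)

model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
```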

@@ -0,0 +1,4 @@
torch==1.12.1
titans==0.0.7
colossalai==0.2.0+torch1.12cu11.3
-f https://release.colossalai.org

@@ -0,0 +1,2 @@
export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch

@@ -0,0 +1 @@
colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch --use_dummy_dataset

@@ -0,0 +1,148 @@
import contextlib
import os

import torch
import torch.nn as nn
from titans.model.gpt import GPTLMLoss

import colossalai
import colossalai.utils as utils
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.init_ctx import ZeroInitContext


def calc_local_model_size(model: torch.nn.Module):
    numel_per_device = 0
    for p in model.parameters():
        numel_per_device += p.numel()
    return numel_per_device


VOCAB_SIZE = 50257


def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=False, action='store_true')
    # default=False so that the flag actually toggles the dummy-data path
    # (with store_true, a True default would make the flag a no-op).
    parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
    args = parser.parse_args()
    disable_existing_loggers()
    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
    logger = get_dist_logger()

    if not args.use_dummy_dataset:
        data_path = os.environ['DATA']
        logger.info(f'Build data loader from path {data_path}', ranks=[0])
        from dataset.webtext import WebtextDataset
        train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
        train_dataloader = utils.get_dataloader(train_ds,
                                                seed=42,
                                                batch_size=gpc.config.BATCH_SIZE,
                                                pin_memory=True,
                                                shuffle=True,
                                                drop_last=True)
    else:
        # build a dummy train_dataloader
        logger.info('Build data loader using dummy data', ranks=[0])

        def get_data(batch_size, seq_len, vocab_size):
            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
            attention_mask = torch.ones_like(input_ids)
            return input_ids, attention_mask

        # 10 iterations
        input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
        from torch.utils.data import DataLoader, Dataset

        class TextSamplerDataset(Dataset):

            def __init__(self, data, seq_len):
                super().__init__()
                self.data = data
                self.seq_len = seq_len

            def __getitem__(self, index):
                rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
                full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
                return full_seq.cuda()

            def __len__(self):
                return self.data.size(0) // self.seq_len

        def cycle(loader):
            while True:
                for data in loader:
                    yield data

        train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
        train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)

    logger.info('Build model', ranks=[0])
    use_pipeline = is_using_pp()
    use_interleaved = hasattr(gpc.config.model, 'num_chunks')
    use_zero3 = hasattr(gpc.config, 'zero')
    ctx = contextlib.nullcontext()
    if use_zero3:
        ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
                              shard_param=True)
    with ctx:
        model = gpc.config.model.pop('type')(**gpc.config.model)
    if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
        model = nn.ModuleList([model])

    if use_zero3:
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    # ~8 FLOPs per parameter per token (forward + backward + recompute under
    # activation checkpointing), scaled to the whole model/data-parallel group.
    tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

    criterion = getattr(gpc.config, 'loss_fn', None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()
    logger.info('Build optimizer', ranks=[0])
    optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
                                                                      optimizer,
                                                                      criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)
    global_batch_size = gpc.config.BATCH_SIZE * \
        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
        # hooks.LogTimingByEpochHook(timer, logger),
    ]
    trainer.fit(train_dataloader=train_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                test_interval=1,
                hooks=hook_list,
                display_progress=True,
                return_output_label=False)


if __name__ == '__main__':
    main()
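The `tflop_per_step` value handed to `ThroughputHook` in the script above follows a common estimate of roughly 8 floating-point operations per parameter per token (about 2 for the forward pass, 4 for the backward pass, and roughly 2 more for re-running the forward pass under activation checkpointing), scaled by the model- and data-parallel world sizes. A self-contained sketch of the same arithmetic, with hypothetical numbers rather than values taken from this commit's configs:

```python
def estimated_tflop_per_step(numel: int, batch_size: int, seq_len: int,
                             model_parallel_size: int = 1, data_parallel_size: int = 1) -> float:
    """Mirror the estimate used in train_gpt.py: ~8 FLOPs per parameter per token."""
    return numel * batch_size * seq_len * model_parallel_size * data_parallel_size * 8 / (1024 ** 4)


if __name__ == '__main__':
    # Hypothetical example: ~124M parameters, batch size 8, sequence length 1024, no parallelism.
    print(f'{estimated_tflop_per_step(124_000_000, 8, 1024):.2f} TFLOP per step')
```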

@@ -11,13 +11,12 @@ import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext