Mirror of https://github.com/hpcaitech/ColossalAI, commit 593a72e4d5

@@ -6,11 +6,13 @@ on:
- cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
publish:
|
||||
if: github.repository == 'hpcaitech/ColossalAI'
|
||||
name: Build and publish Python 🐍 distributions 📦 to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
outputs:
|
||||
status: ${{ steps.publish.outcome }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
|
@@ -18,7 +20,9 @@ jobs:
with:
|
||||
python-version: '3.8.14'
|
||||
|
||||
- run: NIGHTLY=1 python setup.py sdist build
|
||||
- run: |
|
||||
python .github/workflows/scripts/update_setup_for_nightly.py
|
||||
python setup.py sdist build
|
||||
|
||||
# publish to PyPI if executed on the main branch
|
||||
- name: Publish package to PyPI
|
||||
|
@@ -31,7 +35,7 @@ jobs:
|
||||
notify:
|
||||
name: Notify Lark via webhook
|
||||
needs: build-n-publish
|
||||
needs: publish
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ always() }} && github.repository == 'hpcaitech/ColossalAI'
|
||||
steps:
|
||||
|
@@ -62,4 +66,4 @@ jobs:
REPO: ${{ github.repository }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
|
||||
STATUS: ${{ steps.publish.outcome }}
|
||||
STATUS: ${{ needs.publish.outputs.status }}
|
||||
|
|
|
@@ -49,6 +49,6 @@ jobs:
# we need to install the requirements.txt first
|
||||
# as test-pypi may not contain the distributions for libs listed in the txt file
|
||||
pip install -r requirements/requirements.txt
|
||||
pip install --index-url https://test.pypi.org/simple/ colossalai==$VERSION
|
||||
pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.python.org/pypi colossalai==$VERSION
|
||||
env:
|
||||
VERSION: ${{ steps.prep-version.outputs.version }}
|
||||
|
|
|
@@ -0,0 +1,34 @@
from datetime import datetime


def open_setup_file():
    with open("setup.py", "r") as f:
        file_lines = f.readlines()
    return file_lines


def replace_nightly_package_info(file_lines):
    version = datetime.today().strftime("%Y.%m.%d")
    package_name = "colossalai-nightly"

    for idx, line in enumerate(file_lines):
        if "version = get_version()" in line:
            file_lines[idx] = f'version = "{version}"\n'
        if 'package_name = "colossalai"' in line:
            file_lines[idx] = f'package_name = "{package_name}"\n'
    return file_lines


def write_setup_file(file_lines):
    with open("setup.py", "w") as f:
        f.writelines(file_lines)


def main():
    file_lines = open_setup_file()
    file_lines = replace_nightly_package_info(file_lines)
    write_setup_file(file_lines)


if __name__ == "__main__":
    main()
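For reference, a minimal sketch of the substitution this new nightly script performs, applied to an in-memory sample instead of the repository's real setup.py (the sample lines mirror the two strings the script searches for; the printed date is simply whatever day the sketch is run):

```python
# Hypothetical illustration of update_setup_for_nightly.py's rewrite rule.
from datetime import datetime

sample = ['version = get_version()\n', 'package_name = "colossalai"\n']
today = datetime.today().strftime("%Y.%m.%d")

patched = []
for line in sample:
    if "version = get_version()" in line:
        line = f'version = "{today}"\n'  # pin the version to today's date
    if 'package_name = "colossalai"' in line:
        line = 'package_name = "colossalai-nightly"\n'  # publish under the nightly name
    patched.append(line)

print(patched)
```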
|
README.md
@@ -9,7 +9,7 @@
<a href="https://www.colossalai.org/"> Documentation </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
|
||||
<a href="https://medium.com/@hpcaitech"> Blog </a></h3>
|
||||
<a href="https://hpc-ai.com/blog"> Blog </a></h3>
|
||||
|
||||
[![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers)
|
||||
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml)
|
||||
|
@@ -398,10 +398,10 @@ pip install colossalai
|
||||
**Note: only Linux is supported for now.**
|
||||
|
||||
However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
|
||||
However, if you want to build the PyTorch extensions during installation, you can set `BUILD_EXT=1`.
|
||||
|
||||
```bash
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
BUILD_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
**Otherwise, CUDA kernels will be built during runtime when you actually need them.**
|
||||
|
@@ -429,7 +429,7 @@ By default, we do not compile CUDA/C++ kernels. ColossalAI will build them durin
If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
|
||||
|
||||
```shell
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory.
|
||||
|
@@ -445,7 +445,7 @@ unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
|
||||
# install
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<p align="right">(<a href="#top">back to top</a>)</p>
|
||||
|
|
|
@@ -49,12 +49,13 @@ def _preprocess(
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
sequences = [s + t for s, t in zip(sources, targets)]
|
||||
sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)]
|
||||
sequences_token = tokenizer(
|
||||
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
|
||||
)
|
||||
|
||||
sources_token = tokenizer(
|
||||
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
|
||||
)
|
||||
|
||||
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
|
||||
|
@@ -65,7 +66,8 @@ def _preprocess(
if tokenizer.padding_side == "right":
|
||||
# |prompt|completion|eos|pad|
|
||||
labels[i][:source_len] = IGNORE_INDEX
|
||||
labels[i][-pad_len:] = IGNORE_INDEX
|
||||
if pad_len > 0:
|
||||
labels[i][-pad_len:] = IGNORE_INDEX
|
||||
elif tokenizer.padding_side == "left":
|
||||
# |pad|prompt|completion|eos|
|
||||
labels[i][: pad_len + source_len] = IGNORE_INDEX
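A minimal standalone sketch of the masking rule this hunk adjusts, assuming IGNORE_INDEX is -100 as in the surrounding dataset code (tensor values and lengths are illustrative):

```python
import torch

IGNORE_INDEX = -100  # positions with this label are ignored by the loss

def mask_labels(input_ids: torch.Tensor, source_len: int, pad_len: int, padding_side: str) -> torch.Tensor:
    """Mask prompt and padding positions so only completion (and eos) tokens are supervised."""
    labels = input_ids.clone()
    if padding_side == "right":
        # layout: |prompt|completion|eos|pad|
        labels[:source_len] = IGNORE_INDEX
        if pad_len > 0:
            labels[-pad_len:] = IGNORE_INDEX
    elif padding_side == "left":
        # layout: |pad|prompt|completion|eos|
        labels[: pad_len + source_len] = IGNORE_INDEX
    return labels

# Example: 8 tokens, 3-token prompt, 2 pad tokens, right padding.
print(mask_labels(torch.arange(8), source_len=3, pad_len=2, padding_side="right"))
# tensor([-100, -100, -100,    3,    4,    5, -100, -100])
```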
|
||||
|
|
|
@@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--accumulation_steps 8 \
|
||||
--lr 2e-5 \
|
||||
--max_datasets_size 512 \
|
||||
--max_epochs 1
|
||||
--max_epochs 1
|
|
@@ -1,20 +1,16 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from datasets import Dataset as HFDataset
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
@@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
tokenizer: PreTrainedTokenizer
|
||||
max_length: int = 4096
|
||||
ignore_index: int = -100
|
||||
padding: str = "max_length"
|
||||
|
||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
@@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
if self.padding == "max_length":
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
elif self.tokenizer.padding_side == "left":
|
||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
|
@@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def setup_distributed_dataloader(
|
||||
dataset: DatasetType,
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Setup dataloader for distributed training.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset=dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id: int) -> None:
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=drop_last,
|
||||
worker_init_fn=seed_worker,
|
||||
**_kwargs,
|
||||
)
|
||||
|
|
|
@@ -1,15 +1,15 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaForCausalLM,
|
||||
|
@@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
if get_accelerator().name == "cuda":
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attention mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attention mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(
|
||||
output, pattern="... h d -> ... (h d)"
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
|
||||
elif get_accelerator().name == "npu":
|
||||
import torch_npu
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
class NPULlamaAttention(LlamaAttention):
|
||||
use_flash: bool = True
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.use_flash:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
attn_output, *_ = torch_npu.npu_fusion_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self.num_heads,
|
||||
"BNSD",
|
||||
atten_mask=attention_mask.bool(),
|
||||
scale=self._softmax_scale,
|
||||
padding_mask=None,
|
||||
pre_tockens=65535,
|
||||
next_tockens=0,
|
||||
keep_prob=1.0,
|
||||
inner_precise=0,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum(
|
||||
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
class NPURMSNorm(LlamaRMSNorm):
|
||||
def forward(self, hidden_states):
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.__class__ = NPULlamaAttention
|
||||
module.setup()
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.__class__ = NPURMSNorm
|
||||
|
|
|
@@ -17,7 +17,7 @@ import torch
|
||||
def unwrap(model):
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
return model.unwrap()
|
||||
else:
|
||||
return model
|
||||
|
||||
|
|
|
@@ -1,22 +1,21 @@
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from colossal_llama2.dataset.conversation import default_conversation
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
def load_model(model_path, device="cuda", **kwargs):
|
||||
logger.info(
|
||||
"Please check whether the tokenizer and model weights are properly stored in the same folder."
|
||||
)
|
||||
logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
||||
model.to(device)
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
|
||||
except OSError:
|
||||
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
|
||||
|
||||
|
@@ -27,31 +26,51 @@ def load_model(model_path, device="cuda", **kwargs):
def generate(args):
|
||||
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
|
||||
|
||||
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
||||
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
|
||||
if args.prompt_style == "sft":
|
||||
conversation = default_conversation.copy()
|
||||
conversation.append_message("Human", args.input_txt)
|
||||
conversation.append_message("Assistant", None)
|
||||
input_txt = conversation.get_prompt()
|
||||
else:
|
||||
BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
|
||||
input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
|
||||
|
||||
inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
|
||||
output = model.generate(**inputs,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
do_sample=args.do_sample,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
num_return_sequences=1)
|
||||
response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
|
||||
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
|
||||
inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
|
||||
num_input_tokens = inputs["input_ids"].shape[-1]
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
do_sample=args.do_sample,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
num_return_sequences=1,
|
||||
)
|
||||
response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
|
||||
logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
|
||||
parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
|
||||
parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
|
||||
parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
|
||||
parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
|
||||
parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
|
||||
parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
|
||||
parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation")
|
||||
parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="hpcai-tech/Colossal-LLaMA-2-7b-base",
|
||||
help="HF repo name or local path of the model",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
|
||||
)
|
||||
parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
|
||||
parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
|
||||
parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
|
||||
parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
|
||||
parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
|
||||
parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
|
||||
args = parser.parse_args()
|
||||
generate(args)
|
||||
generate(args)
|
||||
|
|
|
@@ -1,9 +1,9 @@
torch<2.0.0, >=1.12.1
|
||||
packaging==23.1
|
||||
colossalai==0.3.2
|
||||
colossalai==0.3.5
|
||||
autoflake==2.2.1
|
||||
black==23.9.1
|
||||
transformers
|
||||
transformers==4.33.3
|
||||
tensorboard==2.14.0
|
||||
six==1.16.0
|
||||
datasets
|
||||
|
|
|
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
--warmup_steps 100 \
|
||||
--use_grad_checkpoint \
|
||||
--use_flash_attn \
|
||||
--pad_token "unk"
|
||||
|
|
|
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
|
@@ -83,6 +85,7 @@ def main() -> None:
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||
|
@@ -108,6 +111,12 @@ def main() -> None:
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
|
@@ -116,6 +125,8 @@ def main() -> None:
)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=1)
|
||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
|
@@ -125,6 +136,7 @@ def main() -> None:
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
accelerator = get_accelerator()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
|
@@ -142,6 +154,7 @@ def main() -> None:
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@@ -149,6 +162,7 @@ def main() -> None:
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
|
@@ -182,7 +196,10 @@ def main() -> None:
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||
# ======================================================
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
if args.pad_token == "eos":
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.pad_token == "unk":
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
|
@@ -193,38 +210,36 @@ def main() -> None:
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = setup_distributed_dataloader(
|
||||
data_collator = DataCollatorForSupervisedDataset(
|
||||
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.micro_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
|
||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
||||
try:
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
current_device = get_accelerator().get_current_device()
|
||||
except:
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
current_device = get_current_device()
|
||||
|
||||
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device())
|
||||
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||
else nullcontext()
|
||||
)
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
||||
# Freeze part of parameters.
|
||||
if args.freeze_non_embeds_params:
|
||||
freeze_non_embeds_parameters(model=model)
|
||||
# this is essential, otherwise the grad checkpoint will not work.
|
||||
model.train()
|
||||
|
||||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
@@ -246,12 +261,14 @@ def main() -> None:
adamw_mode=True,
|
||||
)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * len(dataloader),
|
||||
warmup_steps=args.warmup_steps
|
||||
if args.warmup_steps is not None
|
||||
else int(args.num_epochs * len(dataloader) * 0.025),
|
||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
|
@@ -267,11 +284,9 @@ def main() -> None:
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
if args.load_checkpoint is None:
|
||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
||||
booster.load_model(model, args.pretrained, strict=False)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
@@ -298,85 +313,109 @@ def main() -> None:
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||
# If resume training, set the sampler start index to the correct value
|
||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
with tqdm(
|
||||
iterable=enumerate(dataloader, start=start_step),
|
||||
pbar = tqdm(
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
initial=start_step // args.accumulation_steps,
|
||||
)
|
||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||
for step, batch in enumerate(dataloader, start=start_step):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
batch_output = model(**batch)
|
||||
|
||||
loss = batch_output.loss
|
||||
loss = batch_output.loss / args.accumulation_steps
|
||||
total_loss.add_(loss.data)
|
||||
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(tensor=loss)
|
||||
pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
|
||||
all_reduce_mean(tensor=total_loss)
|
||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||
if coordinator.is_master():
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
# Save modeling.
|
||||
total_loss.fill_(0.0)
|
||||
pbar.update()
|
||||
# Save modeling.
|
||||
|
||||
if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
# Delete CUDA cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
torch.cuda.empty_cache()
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
accelerator.empty_cache()
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
# Delete cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
accelerator.empty_cache()
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(start_index=0)
|
||||
start_step = 0
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
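For readers tracing the accumulation changes in the hunk above, a minimal standalone sketch of the scale-the-loss-then-step pattern the training loop now follows (toy model and optimizer are hypothetical; this is plain PyTorch, not the Booster API used in the actual script):

```python
import torch

accumulation_steps = 4
model = torch.nn.Linear(8, 1)                       # toy stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

micro_batches = torch.randn(32, 8).split(4)         # 8 micro-batches of size 4
for step, batch in enumerate(micro_batches):
    loss = model(batch).pow(2).mean() / accumulation_steps  # scale so summed grads average out
    loss.backward()                                          # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()        # one parameter update per `accumulation_steps` micro-batches
        optimizer.zero_grad()
```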
|
||||
|
|
|
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
|
||||
--pretrained $PRETRAINED_MODEL_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
|
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
--use_grad_checkpoint \
|
||||
--use_flash_attn \
|
||||
--use_neft \
|
||||
--pad_token "eos"
|
||||
|
|
|
@@ -1,403 +0,0 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama2.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Address of the pre-trained modeling",
|
||||
)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
choices=["fp16", "bf16"],
|
||||
help="Mixed precision",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument(
|
||||
"--use_grad_checkpoint",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use gradient checkpointing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Freeze non embeddings parameters",
|
||||
)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if coordinator.is_master():
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=args.zero,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||
# ======================================================
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = setup_distributed_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.micro_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
)
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
# Freeze part of parameters.
|
||||
if args.freeze_non_embeds_params:
|
||||
freeze_non_embeds_parameters(model=model)
|
||||
|
||||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
if args.use_flash_attn:
|
||||
replace_with_flash_attention(model=model)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
||||
optimizer = HybridAdam(
|
||||
model_params=filter(lambda p: p.requires_grad, model.parameters())
|
||||
if args.freeze_non_embeds_params
|
||||
else model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Flash attention will be disabled because it does NOT support fp32.
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
if args.load_checkpoint is None:
|
||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
||||
booster.load_model(model, args.pretrained, strict=False)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load_checkpoint is not None:
|
||||
if "modeling" in args.load_checkpoint:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
|
||||
booster.load_model(model, args.load_checkpoint)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.load_checkpoint,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||
# If resume training, set the sampler start index to the correct value
|
||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
||||
total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
|
||||
for step, batch in enumerate(dataloader):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
||||
loss = batch_output.loss / args.accumulation_steps
|
||||
total_loss += loss.item()
|
||||
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(tensor=total_loss)
|
||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||
if coordinator.is_master():
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
total_loss.fill_(0.0)
|
||||
pbar.update()
|
||||
# Save model checkpoint.
|
||||
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
# Delete CUDA cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# The subsequent epochs do not resume from a checkpoint, so reset the sampler start index and the start step.
|
||||
dataloader.sampler.set_start_index(start_index=0)
|
||||
start_step = 0
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
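A note on the training loop above: the loss is divided by accumulation_steps before booster.backward, and the optimizer only steps once every accumulation_steps micro-batches. A minimal single-process sketch of that gradient-accumulation pattern (plain PyTorch, with a hypothetical toy model and data) is:

import torch

# Hypothetical toy setup, used only to illustrate the accumulation pattern.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
accumulation_steps = 8

data = [(torch.randn(2, 16), torch.randn(2, 1)) for _ in range(32)]

for step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss so the accumulated gradients equal the average over the window.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()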
|
@ -3,6 +3,8 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .huggingface import HuggingFaceModel
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
|
|||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
get_current_device()
|
||||
)
|
||||
|
||||
outputs = self.model(input_ids)[0]
|
||||
|
@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
|
|||
truncation=True,
|
||||
return_tensors="pt",
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
|
|
|
@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
|
|||
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
|
@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
|
|||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
self.model.to(get_current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
self.model.eval()
|
||||
|
@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
|
|||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
get_current_device()
|
||||
)
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
|
||||
|
||||
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
|
@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
|
|||
return_tensors="pt",
|
||||
return_token_type_ids=False,
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
|
@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
|||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
self.model.to(get_current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
|
||||
|
|
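For reference, the sharded-loading path used by these evaluation model classes reduces to the pattern below. This is only a sketch based on the calls that appear in the class itself: path, model_kwargs, and shard_config are assumed to be prepared by the caller, and the helper name is hypothetical.

import torch
from transformers import AutoModelForCausalLM

from colossalai.shardformer import ShardFormer
from colossalai.utils import get_current_device


def load_sharded_causal_lm(path: str, shard_config, **model_kwargs) -> torch.nn.Module:
    # Load the HuggingFace model, let ShardFormer partition it, then move it
    # onto the device reported by the current accelerator backend.
    model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
    shard_former = ShardFormer(shard_config)
    model, _sharded_parameters = shard_former.optimize(model)
    return model.to(get_current_device())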
|
@ -8,6 +8,7 @@ import torch.distributed as dist
|
|||
from colossal_eval import dataset, models, utils
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig
|
||||
|
@ -82,6 +83,7 @@ def rm_and_merge(
|
|||
|
||||
def main(args):
|
||||
colossalai.launch_from_torch(config={}, seed=42)
|
||||
accelerator = get_accelerator()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
@ -235,10 +237,10 @@ def main(args):
|
|||
),
|
||||
)
|
||||
|
||||
logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
|
||||
logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
|
||||
|
||||
del model_
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.empty_cache()
|
||||
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
|
|
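The edits in this file follow a single pattern: hard-coded torch.cuda.* calls are swapped for the device-agnostic helpers get_current_device() and get_accelerator(), so the evaluation script can also run on non-CUDA accelerators. A small usage sketch, based only on calls that appear in these diffs:

import torch

from colossalai.accelerator import get_accelerator
from colossalai.utils import get_current_device

# Move a tensor to whichever device the current accelerator backend exposes.
x = torch.randn(4, 4).to(get_current_device())

# Query memory and clear caches without assuming CUDA.
accelerator = get_accelerator()
print(f"peak device mem: {accelerator.max_memory_allocated() / 1024**3:.3f} GB")
accelerator.empty_cache()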
Binary file not shown.
|
@ -0,0 +1,629 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
|
||||
self.ep_group = moe_info.ep_group
|
||||
self.ep_size = moe_info.ep_size
|
||||
self.ep_rank = moe_info.ep_rank
|
||||
self.real_dp_rank = moe_info.dp_rank
|
||||
|
||||
@staticmethod
|
||||
def _model_sharder(
|
||||
model: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
param_name_pattern: Optional[str] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internal method that breaks the model's state_dict into shards of limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
# Save parameters.
|
||||
for name, param in model.named_parameters():
|
||||
if param is None:
|
||||
continue
|
||||
if param_name_pattern is not None and param_name_pattern not in name:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save buffers.
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save extra states.
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
prefix (str, optional): Prefix of file to save. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model = model.unwrap()
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.real_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# ep_rank 0 saves all the parameters and buffers.
|
||||
# other ep_ranks save only experts
|
||||
ep_param_pattern = "experts." if self.ep_rank != 0 else None
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
weights_name = weights_name.replace(
|
||||
".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
|
||||
)
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# The global master rank integrates the index files and cleans the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for weight, weight_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(weight, weight_filename)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero and not is_moe_param:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
only_moe_param: bool = False,
|
||||
):
|
||||
# An internal method that breaks the optimizer's state_dict into shards of limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
|
||||
if only_moe_param and not is_moe_tensor(working_param):
|
||||
continue
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Prefix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the model.
|
||||
if not self.use_zero and self.real_dp_rank != 0:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
only_moe_param=self.ep_rank != 0,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
dist.barrier()
|
||||
return
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# The global master rank integrates the index files and cleans the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
|
||||
id_map = {}
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep param groups
|
||||
if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that a file will not be loaded repeatedly.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
# If the states in this file have already been loaded, skip it.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
def shard_from_complete_optimizer_state(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero and not is_moe_param:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
raise NotImplementedError
|
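Both _model_sharder and _optimizer_sharder above stream tensors into an accumulating block and emit the block whenever it would exceed a size budget. A simplified, framework-free sketch of that size-capped sharding idea (not the actual StateDictSharder API):

from typing import Dict, Iterator, Tuple

import torch


def shard_state_dict(
    state_dict: Dict[str, torch.Tensor], size_per_shard_mb: int = 1024
) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    """Yield (shard, shard_size_in_bytes) pairs, each shard capped at the given size."""
    limit = size_per_shard_mb * 1024 * 1024
    shard, shard_size = {}, 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and shard_size + tensor_size > limit:
            yield shard, shard_size
            shard, shard_size = {}, 0
        shard[name] = tensor
        shard_size += tensor_size
    if shard:
        yield shard, shard_size


# Example: split a toy model's weights into ~1 MB shards.
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
for i, (shard, size) in enumerate(shard_state_dict(model.state_dict(), size_per_shard_mb=1)):
    print(f"shard {i}: {len(shard)} tensors, {size / 1024**2:.2f} MB")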
|
@ -0,0 +1,92 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
|
||||
|
||||
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.setup_ep()
|
||||
|
||||
def setup_ep(self):
|
||||
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
|
||||
ep_group = moe_info.ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_info(p, moe_info)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
module.setup_ep()
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
selected_experts = selected_experts.t().reshape(-1)
|
||||
selected_experts_idx = selected_experts.argsort()
|
||||
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
|
||||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
# compute expert output
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
# no need to split
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
|
||||
output_states = expert.w2(output_states)
|
||||
else:
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
|
||||
split_states = expert.w2(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||
)
|
||||
dispatch_states = dispatch_states[recover_experts_idx]
|
||||
k_hidden_states = dispatch_states.chunk(self.top_k)
|
||||
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
|
||||
for i in range(1, self.top_k):
|
||||
output_states += k_hidden_states[i] * routing_weights[:, i, None]
|
||||
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return output_states, router_logits
|
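Stripped of the expert-parallel all-to-all communication, the routing math in forward above is the standard top-k MoE recipe: softmax over the router logits, keep the top-k weights per token, renormalize, and mix the selected experts' outputs. A single-process sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, num_experts, top_k = 8, 16, 4, 2

hidden_states = torch.randn(num_tokens, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)
experts = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts))

# Route: softmax over experts, keep the top-k weights per token, renormalize.
router_logits = gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

# Mix: each token's output is the weighted sum of its selected experts' outputs.
output = torch.zeros_like(hidden_states)
for k in range(top_k):
    for token in range(num_tokens):
        expert = experts[selected_experts[token, k]]
        output[token] += routing_weights[token, k].to(hidden_states.dtype) * expert(hidden_states[token])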
|
@ -0,0 +1,557 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, Module
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralDecoderLayer,
|
||||
MixtralForCausalLM,
|
||||
MixtralModel,
|
||||
MoeCausalLMOutputWithPast,
|
||||
_prepare_4d_causal_attention_mask,
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .mixtral_layer import EPMixtralSparseMoeBlock
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
|
||||
|
||||
class MixtralPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
with a customized forward method, and register this change in the policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "MixtralModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class MixtralModelPolicy(MixtralPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralModel,
|
||||
new_forward=MixtralPipelineForwards.mixtral_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for causal LM
|
||||
new_item = {
|
||||
MixtralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls=MixtralForCausalLM,
|
||||
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class MixtralPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Mixtral models
    under pipeline setting.
    """

    @staticmethod
    def mixtral_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            # retrieve input_ids and inputs_embeds
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions, for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    output_router_logits,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None

        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
        if stage_manager.is_last_stage():
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        # always return a dict for intermediate stages
        return {
            "hidden_states": hidden_states,
            "past_router_logits": all_router_logits,
        }
    @staticmethod
    def mixtral_for_causal_lm_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = MixtralPipelineForwards.mixtral_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            past_router_logits=past_router_logits,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            aux_loss = None
            if output_router_logits:
                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                if labels is not None:
                    loss += self.router_aux_loss_coef * aux_loss

            if not return_dict:
                output = (logits,) + outputs[1:]
                if output_router_logits:
                    output = (aux_loss,) + output
                return (loss,) + output if loss is not None else output

            return MoeCausalLMOutputWithPast(
                loss=loss,
                aux_loss=aux_loss,
                logits=logits,
                past_key_values=None,
                hidden_states=outputs[0],
                attentions=None,
                router_logits=outputs[-1],
            )
        else:
            out = {}
            hidden_states = outputs.get("hidden_states")
            out["hidden_states"] = hidden_states
            if output_router_logits:
                out["past_router_logits"] = outputs["past_router_logits"]
            return out
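# --- Illustrative sketch (not part of the original diff) ---
# Under pipeline parallelism, only the last stage returns a tuple/ModelOutput;
# intermediate stages return a dict carrying `hidden_states` (and, when router
# logits are requested, `past_router_logits`) that is fed into the next stage's
# forward. A hedged, schematic chaining of two stages over a 32-layer model:
#
#     out0 = MixtralPipelineForwards.mixtral_model_forward(
#         module, input_ids=input_ids, stage_manager=stage0_mgr, stage_index=[0, 16]
#     )
#     out1 = MixtralPipelineForwards.mixtral_model_forward(
#         module,
#         hidden_states=out0["hidden_states"],
#         past_router_logits=out0.get("past_router_logits"),
#         stage_manager=stage1_mgr,
#         stage_index=[16, 32],
#     )
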
@ -0,0 +1,84 @@
import json
import os
from typing import Any, Dict, Tuple, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)


def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    # Update booster params states.
    booster.load_model(model, os.path.join(load_dir, "modeling"))
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
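
# --- Illustrative usage sketch (not part of the original diff) ---
# A hedged example of how the helpers above are typically combined during
# training; `booster`, `model`, `optimizer`, `lr_scheduler` and `coordinator`
# are assumed to come from an existing ColossalAI training setup.
#
#     save_checkpoint("./outputs", booster, model, optimizer, lr_scheduler,
#                     epoch=0, step=100, batch_size=1, coordinator=coordinator)
#     epoch, step, sample_start_index = load_checkpoint(
#         "./outputs/epoch-0_step-100", booster, model, optimizer, lr_scheduler
#     )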
@ -0,0 +1,111 @@
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="ep",
        choices=["ep"],
        help="Parallel method.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    config = MixtralConfig.from_pretrained(args.model_name)
    ep_size = min(dist.get_world_size(), config.num_local_experts)
    # Set plugin
    if args.plugin == "ep":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=1,
            ep_size=ep_size,
            zero_stage=1,
            precision=args.precision,
            custom_policy=MixtralForCausalLMPolicy(),
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master("Finished loading model")

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Set booster
    booster = Booster(plugin=plugin)
    model, _, _, _, _ = booster.boost(model=model)
    coordinator.print_on_master("Finished initializing booster")

    model.eval()

    if coordinator.rank == 0:
        text = ["Hello my name is"]
    else:
        text = [
            "What's the largest country in the world?",
            "How many people live in China?",
            "帮我续写这首诗:离离原上草",
        ]
    tokenizer.pad_token = tokenizer.unk_token
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())

    with torch.no_grad():
        outputs = model.module.generate(**inputs, max_new_tokens=20)
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(f"[{coordinator.rank}] {outputs}")


if __name__ == "__main__":
    main()
@ -0,0 +1,7 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"

# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
    --model_name $MODEL \
    --plugin "ep" \
@ -0,0 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
@ -0,0 +1,43 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
@ -0,0 +1,63 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
from torch.testing import assert_close
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import colossalai
from colossalai.moe import MOE_MANAGER
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    MOE_MANAGER.setup(
        parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
    )
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
    )
    torch.manual_seed(0)
    orig_model = MixtralSparseMoeBlock(config).cuda()
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output, orig_logits = orig_model(x)
    model = deepcopy(orig_model)
    model = EPMixtralSparseMoeBlock.from_native_module(model)
    ep_output, ep_logits = model(x)
    assert_close(orig_logits, ep_logits)
    assert_close(orig_output, ep_output)
    orig_loss = orig_output.mean()
    orig_loss.backward()
    ep_loss = ep_output.mean()
    ep_loss.backward()
    assert_close(orig_loss, ep_loss)
    name_to_p = {n: p for n, p in orig_model.named_parameters()}
    for n, ep_p in model.named_parameters():
        p = name_to_p[n]
        if ep_p.grad is not None:
            assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(2)
@ -0,0 +1,146 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_model_equal(model1, model2):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(p1.half(), p2.half())


def get_optimizer_snapshot(optim):
    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
    param_groups = []
    for group in optim.param_groups:
        params = [id(p) for p in group["params"]]
        new_group = {"params": params}
        for k, v in group.items():
            if k != "params":
                new_group[k] = v
        param_groups.append(new_group)
    return {
        "state": state,
        "param_groups": param_groups,
    }


def check_optimizer_snapshot_equal(snapshot1, snapshot2):
    # check param_groups
    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
        assert set(group1.keys()) == set(group2.keys())
        for k in group1.keys():
            assert group1[k] == group2[k]
    # check state
    assert set(snapshot1["state"].keys()) == set(
        snapshot2["state"].keys()
    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
    for pid in snapshot1["state"].keys():
        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
        assert set(state1.keys()) == set(state2.keys())
        for k in state1.keys():
            if isinstance(state1[k], torch.Tensor):
                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
            else:
                assert state1[k] == state2[k]


def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
    orig_model = MixtralForCausalLM(config).cuda()
    model = deepcopy(orig_model)
    optimizer = Adam(model.parameters(), lr=1e-3)
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
    # initialize grads
    data_iter = iter(
        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
    )
    booster.execute_pipeline(
        data_iter,
        model,
        lambda outputs, inputs: outputs.loss,
        optimizer,
    )

    # check save model
    booster.save_model(model, "mixtral_model", shard=True)
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
        check_model_equal(orig_model, saved_model)
        saved_model.save_pretrained("mixtral_hf_model")
    dist.barrier()

    # check load model
    new_model = MixtralForCausalLM(config).cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, "mixtral_hf_model")
    check_model_equal(model, new_model)

    # check save optimizer
    optimizer.step()
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    snapshot = get_optimizer_snapshot(optimizer.unwrap())
    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
    dist.barrier()
    # reset optimizer state
    for state in optimizer.unwrap().state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    booster.load_optimizer(optimizer, "mixtral_optim")
    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(4)
@ -0,0 +1,295 @@
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@torch.no_grad()
def get_global_loss(loss, booster):
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communication overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            custom_policy=MixtralForCausalLMPolicy(),
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        )

    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master("Finished initializing model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master("Finished initializing booster")

    # Load checkpoint
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master("Finished loading checkpoint")

    # Start finetuning
    coordinator.print_on_master("Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=(not is_pp_last_stage) if use_pipeline else (not coordinator.is_master()),
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master("Finished training")


if __name__ == "__main__":
    main()
@ -0,0 +1,19 @@
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001

# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
    train.py \
    --num_epoch 1 \
    --model_name $MODEL \
    --plugin "hybrid" \
    --batch_size $BATCH_SIZE \
    --lr $LR \
    --zero_stage 1 \
    --pp_size 2 \
    --dp_size 1 \
    --ep_size 8 \
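# Note (not part of the original diff): with the hybrid plugin, the total world
# size (number of hosts in "hostfile" x NUM_GPU) must be divisible by
# tp_size * pp_size * ep_size, per the assertions added to MoeHybridParallelPlugin
# later in this diff; pp_size=2 and ep_size=8 above therefore imply a multi-node
# run (e.g. 2 nodes x 8 GPUs).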
@ -0,0 +1 @@
1.0.0
@ -4,9 +4,9 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.constants import BCAST_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

from ..constants import BCAST_FUNC_OP
from ..registry import meta_register

__all__ = ["binary_elementwise_meta_info"]
@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
        self.world_size = dist.get_world_size()

    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)

        # Deterministic dataloader
        def seed_worker(worker_id):
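        # --- Illustrative sketch (not part of the original diff) ---
        # The new `distributed_sampler_cls` argument lets callers swap in a custom
        # sampler class while keeping DistributedSampler as the default. A hedged
        # example with a hypothetical StatefulDistributedSampler subclass:
        #
        #     class StatefulDistributedSampler(DistributedSampler):
        #         ...  # e.g. resume iteration from a saved sample index
        #
        #     dataloader = plugin.prepare_dataloader(
        #         dataset, batch_size=8, shuffle=True,
        #         distributed_sampler_cls=StatefulDistributedSampler,
        #     )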
@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
        return ["cuda", "npu"]

    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
        extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
        zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
        extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
        sampler = DistributedSampler(
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(
            dataset,
            num_replicas=zero_world_size * extra_dp_world_size,
            rank=zero_rank * extra_dp_world_size + extra_dp_rank,
@ -1,5 +1,6 @@
import ctypes
import random
import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
                tp_process_group=self.tp_group,
            )
        else:
            assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
            if self.dp_size == 1:
                warnings.warn(
                    "Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
                    "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
                )

            assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
            optimizer = HybridParallelZeroOptimizer(
                optimizer,
@ -1199,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
        return outputs

    def prepare_dataloader(
        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
@ -1223,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
        """
        _kwargs = kwargs.copy()
        sampler = DistributedSampler(
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        sampler = distributed_sampler_cls(
            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
        )
@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self,
        tp_size: int,
        pp_size: int,
        ep_size: int,
        extra_dp_size: int = 1,
        precision: str = "fp16",
        zero_stage: int = 0,
@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        overlap_communication: bool = True,
        use_ep_inside: bool = True,
        custom_policy: Policy = None,
        checkpoint_io: Optional[MoECheckpintIO] = None,
    ) -> None:
        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0
@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

        if enable_sequence_parallelism:
            assert tp_size > 1, "Tensor parallelism (tp_size > 1) must be enabled when using sequence parallelism"

        assert (
            dist.get_world_size() % (tp_size * pp_size) == 0
        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
        assert (
            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=self.real_dp_size,
            fixed_ep_size=ep_size,
            fixed_pp_size=pp_size,
            use_ep_inside=use_ep_inside,
        )
        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
        self.ep_size = ep_size
        self.moe_info = MOE_MANAGER.get_info(0)[1]
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        self.checkpoint_io = checkpoint_io
        # we change pg mesh to (pp, dp, tp) for better moe performance
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        )

    def get_checkpoint_io(self) -> MoECheckpintIO:
        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        if self.checkpoint_io is None:
            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        else:
            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
        return self.checkpoint_io

    def configure(
@ -1,3 +1,5 @@
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
@ -25,7 +27,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):

    def save_sharded_model(
        self,
        model: nn.Module,
        checkpoint: str,
        gather_dtensor: bool,
        prefix: Optional[str],
        size_per_shard: int,
        use_safetensors: bool,
        model: ModelWrapper,
        checkpoint_path: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ):
        """
        Save model to checkpoint but only on master process.
        """
        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
        if os.path.isfile(checkpoint_path):
            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
            return

        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
        with FSDP.state_dict_type(
            model.unwrap(),
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        ):
            state_dict = model.unwrap().state_dict()

        state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)

        weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint_path)

        # In general cases, is_master is set to True to get the right behavior.
        total_size = utils.save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_path,
            index_file=index_file,
            base_filename=weights_name,
            is_master=self.coordinator.is_master(),
            use_safetensors=use_safetensors,
        )

        # only save the index file on the master rank
        if self.coordinator.is_master():
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            utils.save_config_file(model.unwrap(), checkpoint_path)
            logging.info(
                f"The model is split into checkpoint shards. "
                f"You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_model(
        self,
@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
        """
        Load model from checkpoint but only on master process.
        """
        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not utils.is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # read checkpoint index file
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

        fsdp_state_dict = {}
        for shard_file in checkpoint_files:
            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))

        with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
            model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

    def save_sharded_optimizer(
        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with FSDP.state_dict_type(
|
||||
optimizer.unwrap_model().unwrap(),
|
||||
StateDictType.FULL_STATE_DICT,
|
||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
):
|
||||
fsdp_optim_state = FSDP.full_optim_state_dict(
|
||||
optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
|
||||
)
|
||||
|
||||
if self.coordinator.is_master():
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
utils.save_param_groups(fsdp_optim_state, group_file_path)
|
||||
|
||||
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
# In general cases, is_master is set to True to get the right behavior.
|
||||
total_size = utils.save_state_dict_shards(
|
||||
sharded_state_dict=sharded_state,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=self.coordinator.is_master(),
|
||||
use_safetensors=False,
|
||||
)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
|
||||
"""
|
||||
Load optimizer from checkpoint but only on master process.
|
||||
"""
|
||||
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
|
||||
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
|
||||
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {index_file_path} for an optimizer. "
|
||||
"Looking param group file under current directory."
|
||||
)
|
||||
|
||||
saved_param_groups = torch.load(param_group_path)
|
||||
|
||||
# Load param
|
||||
fsdp_optim_state = {}
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
fsdp_optim_state.update(state_dict_shard)
|
||||
|
||||
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
|
||||
|
||||
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
|
||||
fsdp_state = FSDP.optim_state_dict_to_load(
|
||||
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
|
||||
)
|
||||
optimizer.load_state_dict(fsdp_state)
|
||||
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
|
@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
False
|
||||
return False
|
||||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
|
||||
|
|
|
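For reference, a minimal end-to-end sketch of the sharded FSDP checkpointing this diff enables, mirroring the usage in the test added further below (the toy model, the `./ckpt` paths, and the torchrun launch are illustrative assumptions):

```python
# Sketch only: run with `torchrun --nproc_per_node=<gpus> save_load_fsdp.py`.
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch(config={})
booster = Booster(plugin=TorchFSDPPlugin())

model = nn.Linear(8, 8).cuda()          # illustrative toy model
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True exercises the new save_sharded_model / save_sharded_optimizer paths:
# rank 0 gathers the full state and writes shards plus an index file.
booster.save_model(model, "./ckpt/model", shard=True)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True)

# Loading reads the index file, merges the shards, and hands them back to FSDP.
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
```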
@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.interface import ModelWrapper
|
||||
|
||||
from .utils import has_index_file
|
||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
|
||||
|
||||
__all__ = ["CheckpointIO"]
|
||||
|
||||
|
@ -90,7 +90,15 @@ class CheckpointIO(ABC):
|
|||
if index_file_exists:
|
||||
self.load_sharded_model(model, index_file_path, strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
path = Path(checkpoint, WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
|
||||
return origin_model
|
||||
|
||||
|
|
|
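The new fallback in `load_model` can be summarized as the following standalone sketch; the concrete file names follow the usual Hugging Face-style convention and are assumptions here (the real values come from `SAFE_WEIGHTS_NAME` / `WEIGHTS_NAME`):

```python
from pathlib import Path

SAFE_WEIGHTS_NAME = "model.safetensors"  # assumed value
WEIGHTS_NAME = "pytorch_model.bin"       # assumed value


def resolve_unsharded_checkpoint(checkpoint: str) -> str:
    """Return the path that load_unsharded_model would receive."""
    safe_path = Path(checkpoint, SAFE_WEIGHTS_NAME)
    if safe_path.is_file():
        return str(safe_path)            # prefer safetensors inside a directory
    torch_path = Path(checkpoint, WEIGHTS_NAME)
    if torch_path.is_file():
        return str(torch_path)           # then the plain PyTorch weight file
    return checkpoint                    # finally, treat the path itself as the weights
```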
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
from functools import reduce
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
|
@ -445,7 +446,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
|
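On toy data, the new param-group construction keeps the optimizer's current hyperparameters and only takes the parameter index lists from `param_info`, rather than saving the stale values recorded at boost time (values below are made up):

```python
live_param_groups = [{"lr": 0.05, "weight_decay": 0.0, "params": ["<live tensors>"]}]
param_info = {"param_groups": [{"lr": 0.001, "params": [0, 1, 2]}]}

param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(live_param_groups, param_info["param_groups"])
]
# The saved group carries the updated lr (0.05) but the recorded param ids.
assert param_groups == [{"lr": 0.05, "weight_decay": 0.0, "params": [0, 1, 2]}]
```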
@ -504,7 +509,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
save_param_groups({"param_groups": param_groups}, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
@ -713,12 +722,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
tp_group=self.tp_group,
|
||||
use_zero=self.use_zero,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
device=get_current_device(),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
state_dict = {"param_groups": param_groups, "state": local_states}
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
else:
|
||||
|
@ -729,7 +742,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
|
||||
# Only the master rank do the saving.
|
||||
if self.coordinator.is_master():
|
||||
state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
|
||||
]
|
||||
state_dict = {"param_groups": param_groups, "state": dict()}
|
||||
for _states in states_list:
|
||||
state_dict["state"].update(_states)
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
@ -838,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
v = v.cuda()
|
||||
v = v.to(get_current_device())
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from .checkpoint import MoECheckpintIO
|
||||
from .experts import MLPExperts
|
||||
from .layers import SparseMLP
|
||||
from .layers import SparseMLP, apply_load_balance
|
||||
from .manager import MOE_MANAGER
|
||||
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
|
||||
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
|
||||
|
||||
|
@ -14,4 +15,6 @@ __all__ = [
|
|||
"UniformNoiseGenerator",
|
||||
"SparseMLP",
|
||||
"MoECheckpintIO",
|
||||
"MOE_MANAGER",
|
||||
"apply_load_balance",
|
||||
]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
|
|||
if ctx.ep_size != 1:
|
||||
grad = grad / ctx.ep_size
|
||||
return grad, None
|
||||
|
||||
|
||||
def _all_to_all(
|
||||
inputs: torch.Tensor,
|
||||
input_split_sizes: Optional[List[int]] = None,
|
||||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
outputs_shape = list(inputs.shape)
|
||||
if output_split_sizes is not None:
|
||||
outputs_shape[0] = sum(output_split_sizes)
|
||||
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
inputs = inputs.contiguous()
|
||||
outputs = outputs.contiguous()
|
||||
handle = dist.all_to_all_single(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
||||
)
|
||||
return outputs, handle
|
||||
|
||||
|
||||
class AllToAllUneven(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
inputs,
|
||||
input_split_sizes=None,
|
||||
output_split_sizes=None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
ctx.input_split_sizes = input_split_sizes
|
||||
ctx.output_split_sizes = output_split_sizes
|
||||
ctx.group = group
|
||||
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs):
|
||||
return (
|
||||
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def all_to_all_uneven(
|
||||
inputs: torch.Tensor,
|
||||
input_split_sizes: Optional[List[int]] = None,
|
||||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
):
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
|
|
|
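A small usage sketch of the new `all_to_all_uneven` helper; it assumes two NCCL ranks launched with torchrun and that the function is importable from the MoE operation module (the module path and the split sizes are illustrative):

```python
# torchrun --nproc_per_node=2 all_to_all_uneven_demo.py
import torch
import torch.distributed as dist

from colossalai.moe._operation import all_to_all_uneven  # module path assumed


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Rank 0 sends 1 row to itself and 3 rows to rank 1; rank 1 sends 2 rows to each.
    input_split_sizes = [1, 3] if rank == 0 else [2, 2]
    # Hence rank 0 receives 1 + 2 = 3 rows and rank 1 receives 3 + 2 = 5 rows.
    output_split_sizes = [1, 2] if rank == 0 else [3, 2]

    inputs = torch.full((sum(input_split_sizes), 4), float(rank), device="cuda")
    outputs, _ = all_to_all_uneven(inputs, input_split_sizes, output_split_sizes)
    print(rank, outputs.shape)  # rank 0: (3, 4), rank 1: (5, 4)


if __name__ == "__main__":
    main()
```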
@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
f"index located at {save_index_file}."
|
||||
)
|
||||
dist.barrier()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# ========================================================
|
||||
# Abstract methods for optimizer loading/saving implementation
|
||||
|
@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
|
@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep extra group
|
||||
if MOE_MANAGER.parallel == "EP":
|
||||
# ep param group
|
||||
if len(optimizer.optim.param_groups) > len(saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
|
@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for pid, state in list(state_dict.items()):
|
||||
if pid in id_map:
|
||||
param = id_map[pid]
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif (
|
||||
hasattr(optimizer, "moe_master_to_working_map")
|
||||
and id(param) in optimizer.moe_master_to_working_map
|
||||
):
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device="cpu",
|
||||
inplace=True,
|
||||
)
|
||||
state_dict[pid] = sharded_state
|
||||
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
|
@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
prefix (str): Prefix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
|
@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
|
|
|
@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
|
|||
self.ep_size = 1
|
||||
|
||||
if gated:
|
||||
self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
|
||||
self.wi_gate = nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
|
||||
)
|
||||
)
|
||||
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
else:
|
||||
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
|
||||
|
|
|
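The 2x width is only needed because a SwiGLU-style activation splits its pre-activation in half along the last dimension; a shape-only sketch (an illustrative formulation, not necessarily the exact kernel used by `MLPExperts`):

```python
import torch
import torch.nn.functional as F


def swiglu(x: torch.Tensor) -> torch.Tensor:
    value, gate = x.chunk(2, dim=-1)     # split the doubled projection
    return F.silu(gate) * value


num_experts, hidden, inter, tokens = 2, 8, 16, 4
wi_gate = torch.randn(num_experts, hidden, inter * 2)   # 2x only for swiglu
x = torch.randn(num_experts, tokens, hidden)

pre_act = torch.bmm(x, wi_gate)          # (num_experts, tokens, 2 * inter)
out = swiglu(pre_act)                    # (num_experts, tokens, inter)
assert out.shape == (num_experts, tokens, inter)
```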
@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
|
|||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
router_top_k: int = 1,
|
||||
router_loss: bool = True,
|
||||
router_norm: bool = False,
|
||||
router_capacity_factor_train: float = 1.25,
|
||||
router_capacity_factor_eval: float = 2.0,
|
||||
router_min_capacity: int = 4,
|
||||
|
@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
|
|||
enable_kernel: bool = False,
|
||||
enable_comm_overlap: bool = False,
|
||||
enable_hierarchical_comm: bool = False,
|
||||
return_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_experts = num_experts
|
||||
self.gated = mlp_gated
|
||||
self.return_gate_logits = return_gate_logits
|
||||
self.enable_kernel = enable_kernel
|
||||
self.enable_comm_overlap = enable_comm_overlap
|
||||
self.expert_parallel = MOE_MANAGER.get_parallel()
|
||||
self.router_loss = router_loss
|
||||
self.router_norm = router_norm
|
||||
|
||||
# moe router
|
||||
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
|
||||
|
@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
|
|||
tokens = inputs.reshape(-1, self.hidden_size)
|
||||
|
||||
# the data type of the inputs in the gating should be fp32
|
||||
fp32_input = tokens.to(torch.float)
|
||||
fp32_weight = self.gate_weight.to(torch.float)
|
||||
gate_output = F.linear(fp32_input, fp32_weight)
|
||||
gate_logits = F.linear(tokens, self.gate_weight)
|
||||
gate_output = gate_logits.to(torch.float)
|
||||
|
||||
# update expert load
|
||||
if self.enable_load_balance == True:
|
||||
|
@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
|
|||
|
||||
# the result from the router
|
||||
used_capacity, *route_result_list = self.router(
|
||||
inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
|
||||
inputs=gate_output,
|
||||
use_kernel=self.enable_kernel,
|
||||
ep_group=self.ep_group,
|
||||
use_loss=self.router_loss,
|
||||
use_norm=self.router_norm,
|
||||
)
|
||||
|
||||
# dispatch_data: (num_experts, capacity, hidden_size)
|
||||
if self.enable_kernel:
|
||||
|
@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
|
|||
|
||||
# expert_output: (num_groups, num_experts, capacity, hidden_size)
|
||||
if self.expert_parallel == "EP":
|
||||
expert_output = self._ep_process(
|
||||
dispatch_data,
|
||||
used_capacity,
|
||||
overlap=self.enable_comm_overlap
|
||||
)
|
||||
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel == "TP":
|
||||
expert_output = self._tp_process(
|
||||
dispatch_data,
|
||||
used_capacity,
|
||||
overlap=self.enable_comm_overlap
|
||||
)
|
||||
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
|
||||
elif self.expert_parallel is None:
|
||||
expert_output = self._local_process(dispatch_data)
|
||||
else:
|
||||
raise NotImplementedError("This kind of communication has not been implemented yet.\n"
|
||||
"Please use Experts build function.")
|
||||
raise NotImplementedError(
|
||||
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
|
||||
)
|
||||
|
||||
if self.enable_kernel:
|
||||
expert_output = expert_output.reshape(-1, self.hidden_size)
|
||||
|
@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
|
|||
ans = torch.matmul(combine_weights, expert_output)
|
||||
|
||||
ans = ans.reshape(inputs.shape)
|
||||
return ans
|
||||
|
||||
if self.return_gate_logits:
|
||||
return ans, gate_logits
|
||||
else:
|
||||
return ans
|
||||
|
||||
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
|
||||
expert_in = expert_in.unsqueeze(0)
|
||||
|
@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
|
|||
return expert_out
|
||||
|
||||
def _ep_process(
|
||||
self,
|
||||
dispatch_data: torch.Tensor,
|
||||
used_capacity: torch.Tensor,
|
||||
overlap: bool = False
|
||||
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expert Parallel
|
||||
|
@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
|
|||
"""
|
||||
if not overlap or dist.get_world_size(self.ep_group) == 1:
|
||||
if self.ep_hierarchical_group is not None:
|
||||
expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
|
||||
expert_input = HierarchicalAllToAll.apply(
|
||||
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
|
||||
)
|
||||
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
|
||||
expert_output = self.experts(expert_input)
|
||||
expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
|
||||
expert_output = HierarchicalAllToAll.apply(
|
||||
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
|
||||
)
|
||||
return expert_output
|
||||
else:
|
||||
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
|
||||
|
@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
|
|||
NUM_CHUNK = 4
|
||||
NUM_STAGES = 4
|
||||
|
||||
assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
|
||||
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
|
||||
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
|
||||
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
|
||||
dispatch_data = dispatch_data.reshape(*input_shape)
|
||||
|
@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
|
|||
for i in range(NUM_CHUNK + NUM_STAGES - 1):
|
||||
if expert_out is not None:
|
||||
expert_out.handle.wait()
|
||||
output[:, :, offset:offset + chunk_size, :] = expert_out.data
|
||||
output[:, :, offset : offset + chunk_size, :] = expert_out.data
|
||||
offset += chunk_size
|
||||
expert_out = None
|
||||
|
||||
# all2all last output
|
||||
if _expert_out is not None:
|
||||
expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
|
||||
expert_out = Capsule(
|
||||
*AllToAll.apply(_expert_out.data, self.ep_group, True),
|
||||
)
|
||||
_expert_out = None
|
||||
|
||||
# all2all next input
|
||||
|
@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
|
|||
return output
|
||||
|
||||
def _tp_process(
|
||||
self,
|
||||
dispatch_data: torch.Tensor,
|
||||
used_capacity: torch.Tensor,
|
||||
overlap: bool = False
|
||||
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
without overlap:
|
||||
|
@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
|
|||
NUM_CHUNK = 4
|
||||
NUM_STAGES = 4
|
||||
|
||||
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
|
||||
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
|
||||
assert (
|
||||
dispatch_data.shape[0] % NUM_CHUNK == 0
|
||||
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
|
||||
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
|
||||
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
|
||||
output = torch.empty_like(dispatch_data)
|
||||
|
|
|
@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
|
|||
self._z_loss = None
|
||||
self.use_kernel = use_kernel
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
def get_capacity(self, num_tokens, num_experts, ep_group=None):
|
||||
if ep_group is not None:
|
||||
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
|
||||
dist.all_reduce(num_tokens_tensor, group=ep_group)
|
||||
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
|
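A worked instance of the rewritten capacity formula, with made-up settings (top-1 routing, `capacity_factor_train=1.25`, `min_capacity=4`):

```python
import math

k_value, capacity_factor, min_capacity = 1, 1.25, 4
num_experts, num_tokens = 8, 1024   # without an ep_group, num_tokens is the local count

capacity = math.floor(k_value * capacity_factor * num_tokens / num_experts)  # 160
capacity += capacity % 2            # keep the capacity even (160 stays 160)
capacity = max(capacity, min_capacity)
assert capacity == 160

# With an ep_group, num_tokens is first all-reduced over the group and divided by
# the group size, i.e. replaced by the average token count per expert-parallel rank.
```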
@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
|
|||
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_loss: bool = False,
|
||||
use_norm: bool = False,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
|
|||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
num_experts = probs.size(-1)
|
||||
capacity = self.get_capacity(inputs.shape)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
|
|||
weight = mask * probs.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return used_capacity, combine_weights, sec_mask
|
||||
return used_capacity, combine_weights, sec_mask, probs
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
|
@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
|
|||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
use_kernel: bool = False,
|
||||
ep_group: Optional[ProcessGroup] = None,
|
||||
use_norm: bool = False,
|
||||
use_loss: bool = True,
|
||||
) -> Tuple:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
|
||||
|
@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
|
|||
|
||||
assert inputs.dtype == torch.float
|
||||
probs = F.softmax(inputs, dim=-1)
|
||||
if use_norm:
|
||||
routing_weights, _ = torch.topk(probs, 2, dim=-1)
|
||||
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
num_experts = probs.size(-1)
|
||||
capacity = self.get_capacity(inputs.shape)
|
||||
num_tokens = inputs.size(0)
|
||||
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
||||
|
||||
top1_idx = torch.argmax(probs, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
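A quick numeric check of the optional top-2 renormalization introduced above: after dividing by the sum of the two largest probabilities, the two selected experts' weights sum to one for every token (the logits below are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1], [0.3, 0.2, 2.5]])
probs = F.softmax(logits, dim=-1)

routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)

top2 = torch.topk(probs, 2, dim=-1).values
assert torch.allclose(top2.sum(dim=-1), torch.ones(2))
```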
@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
|
|||
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
|
||||
|
||||
# calculate loss
|
||||
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
|
||||
self.set_aux_loss(probs, expert_indices, num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
if use_loss:
|
||||
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
|
||||
self.set_aux_loss(probs, expert_indices, num_experts)
|
||||
self.set_z_loss(inputs)
|
||||
self.pop_router_loss()
|
||||
|
||||
if not self.training and not self.drop_tks and ep_group is not None:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
|
|
|
@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
|
|||
return torch.nn.GELU()
|
||||
elif act == "swiglu":
|
||||
return SwiGLU
|
||||
elif act == "silu":
|
||||
return torch.nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError("Unsupported activation function")
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ if Version(torch.__version__) >= Version("2.0.0"):
|
|||
else:
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class _enable_get_lr_call:
|
||||
def __init__(self, o):
|
||||
|
@ -19,7 +21,39 @@ class _enable_get_lr_call:
|
|||
self.o._get_lr_called_within_step = False
|
||||
|
||||
|
||||
class DelayerScheduler(_LRScheduler):
|
||||
class TwoStageScheduler(_LRScheduler):
|
||||
def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1):
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
|
||||
if isinstance(state_dict["after_scheduler"], _LRScheduler):
|
||||
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
|
||||
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
|
||||
del state_dict["after_scheduler"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if "after_scheduler_dict" not in state_dict:
|
||||
logger = get_dist_logger()
|
||||
logger.warning(
|
||||
"after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior."
|
||||
)
|
||||
else:
|
||||
self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"])
|
||||
state_dict = {
|
||||
key: value
|
||||
for key, value in state_dict.items()
|
||||
if key not in ("after_scheduler_type", "after_scheduler_dict")
|
||||
}
|
||||
super().load_state_dict(state_dict)
|
||||
|
||||
|
||||
class DelayerScheduler(TwoStageScheduler):
|
||||
"""Starts with a flat lr schedule until it reaches N epochs then applies
|
||||
the specific scheduler (For example: ReduceLROnPlateau)
|
||||
|
||||
|
@ -35,19 +69,7 @@ class DelayerScheduler(_LRScheduler):
|
|||
if delay_epochs < 0:
|
||||
raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}")
|
||||
self.delay_epochs = delay_epochs
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
|
||||
if isinstance(state_dict["after_scheduler"], _LRScheduler):
|
||||
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
|
||||
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
|
||||
del state_dict["after_scheduler"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
super().__init__(optimizer, after_scheduler, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.delay_epochs:
|
||||
|
@ -71,7 +93,7 @@ class DelayerScheduler(_LRScheduler):
|
|||
return super(DelayerScheduler, self).step(epoch)
|
||||
|
||||
|
||||
class WarmupScheduler(_LRScheduler):
|
||||
class WarmupScheduler(TwoStageScheduler):
|
||||
"""Starts with a linear warmup lr schedule until it reaches N epochs then applies
|
||||
the specific scheduler (For example: ReduceLROnPlateau).
|
||||
|
||||
|
@ -85,19 +107,7 @@ class WarmupScheduler(_LRScheduler):
|
|||
|
||||
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
|
||||
self.warmup_epochs = int(warmup_epochs)
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
|
||||
if isinstance(state_dict["after_scheduler"], _LRScheduler):
|
||||
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
|
||||
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
|
||||
del state_dict["after_scheduler"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
super().__init__(optimizer, after_scheduler, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs:
|
||||
|
@ -120,7 +130,7 @@ class WarmupScheduler(_LRScheduler):
|
|||
return super().step(epoch)
|
||||
|
||||
|
||||
class WarmupDelayerScheduler(_LRScheduler):
|
||||
class WarmupDelayerScheduler(TwoStageScheduler):
|
||||
"""Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule
|
||||
until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau).
|
||||
|
||||
|
@ -140,19 +150,7 @@ class WarmupDelayerScheduler(_LRScheduler):
|
|||
raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}")
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.delay_epochs = delay_epochs
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
|
||||
if isinstance(state_dict["after_scheduler"], _LRScheduler):
|
||||
state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
|
||||
state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
|
||||
del state_dict["after_scheduler"]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
super().__init__(optimizer, after_scheduler, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
|
||||
|
|
|
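Usage of the consolidated schedulers is unchanged; a minimal sketch of the state-dict round trip that `TwoStageScheduler` now centralizes (the import path is an assumption, the constructor signature is taken from the diff):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from colossalai.nn.lr_scheduler.delayed import WarmupScheduler  # path assumed

model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)

# Linear warmup for 5 epochs, then hand control to the wrapped scheduler.
scheduler = WarmupScheduler(optimizer, 5, CosineAnnealingLR(optimizer, T_max=50))

# state_dict() flattens the wrapped scheduler into "after_scheduler_dict",
# so a plain torch.save / torch.load round trip restores both stages.
state = scheduler.state_dict()
scheduler.load_state_dict(state)
```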
@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer._operation import _gather
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||
|
@ -288,6 +289,9 @@ class LlamaPipelineForwards:
|
|||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
|
|
@ -34,6 +34,7 @@ class ShardConfig:
|
|||
enable_all_optimization: bool = False
|
||||
enable_sequence_parallelism: bool = False
|
||||
enable_sequence_overlap: bool = False
|
||||
parallel_output = True
|
||||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
|
|
|
@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
|||
|
||||
from .colo_tensor import _convert_output
|
||||
|
||||
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
|
||||
WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
|
||||
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
|
||||
|
||||
|
||||
def is_no_hook_op(func) -> bool:
|
||||
return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
|
||||
return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
|
||||
|
||||
|
||||
def filter_colo_parameters(*args, **kwargs):
|
||||
|
|
|
@ -26,3 +26,5 @@ class MoeParallelInfo:
|
|||
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
|
||||
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
|
||||
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
|
||||
self.ep_rank = self.pg.coordinate(self.ep_axis)
|
||||
self.dp_rank = self.pg.coordinate(self.dp_axis)
|
||||
|
|
|
@ -92,7 +92,10 @@ class ColoParamOpHookManager:
|
|||
@staticmethod
|
||||
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
||||
ColoParamOpHookManager._trigger_post_forward(params)
|
||||
return PostFwdPreBwd.apply(params, arg)
|
||||
# in case the output is a tuple, we have to flatten it
|
||||
grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
|
||||
new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
|
||||
return _merge_args(new_grad_args, other_args, grad_flags, spec)
|
||||
|
||||
@staticmethod
|
||||
def has_hook() -> bool:
|
||||
|
@ -113,7 +116,7 @@ class PreFwdPostBwd(torch.autograd.Function):
|
|||
|
||||
class PostFwdPreBwd(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, params, args):
|
||||
def forward(ctx, params, *args):
|
||||
ctx.params = params
|
||||
return args
|
||||
|
||||
|
@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
|
|||
grad_args.append(arg)
|
||||
else:
|
||||
other_args.append(arg)
|
||||
assert len(grad_args) > 0
|
||||
return grad_args, other_args, grad_flags, spec
|
||||
|
||||
|
||||
|
|
|
@ -726,11 +726,13 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
if self.reuse_fp16_chunk:
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.payload.copy_(chunk_32.payload)
|
||||
|
||||
# sync running weights and master weights
|
||||
if self.master_weights:
|
||||
for loaded_chunk in chunk_list:
|
||||
paired_chunk = loaded_chunk.paired_chunk
|
||||
assert paired_chunk is not None
|
||||
paired_chunk.payload.copy_(loaded_chunk.payload)
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
|
|
@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
Return the param_groups in Pytorch format when saving to checkpoint.
|
||||
"""
|
||||
|
||||
param_groups = copy.deepcopy(self.param_groups_backup)
|
||||
param_groups = [
|
||||
{**group, "params": group_info["params"]}
|
||||
for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
|
||||
]
|
||||
|
||||
# To be compatible with pytorch checkpointing,
|
||||
# store extra hyperparameters used by pytorch Adam optimizer.
|
||||
|
|
|
@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# because they have different parallel strategy
|
||||
# so we need to store them separately in param_groups
|
||||
# instead of working_groups
|
||||
moe_params = list()
|
||||
self.working_moe_params = list()
|
||||
|
||||
# iterate over the param group in the optimizer
|
||||
# partition these param groups for data parallel training
|
||||
|
@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if self.moe_extra_dp_pg is None:
|
||||
# skip moe param
|
||||
if is_moe_tensor(param):
|
||||
moe_params.append(param)
|
||||
self.working_moe_params.append(param)
|
||||
continue
|
||||
group_params.append(param)
|
||||
|
||||
|
@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# managed by this data parallel rank
|
||||
param_group["params"] = master_param_current_rank
|
||||
|
||||
# if there are moe params, store in additional group in optim
|
||||
if len(moe_params) > 0:
|
||||
# if there are moe params, store them in an additional group in optim
|
||||
if len(self.working_moe_params) > 0:
|
||||
self._sync_master_param = False
|
||||
param_group = dict()
|
||||
# create fp32 master param
|
||||
for key, value in self.optim.param_groups[0].items():
|
||||
if key != "params":
|
||||
param_group[key] = value
|
||||
param_group["params"] = moe_params
|
||||
self.master_moe_params = []
|
||||
for param in self.working_moe_params:
|
||||
self.master_moe_params.append(param.clone().to(torch.float32).detach())
|
||||
# create mapping from master to working for optimizer io
|
||||
self.moe_master_to_working_map = {}
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
|
||||
# add to optim
|
||||
param_group["params"] = self.master_moe_params
|
||||
self.optim.param_groups.append(param_group)
|
||||
|
||||
# initialize communication stream for
|
||||
|
@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
# update the params in the optimizer
|
||||
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
|
||||
|
||||
# update param for moe ep
|
||||
# move grad to master param and compute norm
|
||||
if len(self.working_moe_params) > 0:
|
||||
moe_grads = []
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
if master_moe_param.grad is not None:
|
||||
raise RuntimeError("Moe param should not have grad here")
|
||||
grad = working_moe_param.grad
|
||||
# no need to copy fp32 grad if master_weights is False
|
||||
if self._master_weights:
|
||||
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
|
||||
master_moe_param.grad = grad
|
||||
working_moe_param.grad = None
|
||||
moe_grads.append(grad)
|
||||
grad_partition_groups.append(grad)
|
||||
norm_group = self._compute_grad_norm(gradients=moe_grads)
|
||||
norm_groups.append(norm_group)
|
||||
self.optim.param_groups[-1]["params"] = self.master_moe_params
|
||||
del moe_grads
|
||||
|
||||
# unscale and clip grads
|
||||
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
|
||||
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
|
||||
|
||||
# TODO: we should store master param for ep
|
||||
if len(self.param_groups) > len(self._working_param_groups):
|
||||
for param in self.param_groups[-1]["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
param.grad = param.grad.to(torch.float32)
|
||||
|
||||
# update the parameters
|
||||
self.optim.step()
|
||||
|
||||
# release the moe gradm
|
||||
if len(self.param_groups) > len(self._working_param_groups):
|
||||
for param in self.param_groups[-1]["params"]:
|
||||
param.grad = None
|
||||
param.data = param.data.to(self._dtype)
|
||||
# release moe grad
|
||||
if len(self.working_moe_params) > 0:
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
master_moe_param.grad = None
|
||||
working_moe_param.data = (
|
||||
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
|
||||
)
|
||||
|
||||
# release the grad
|
||||
grad_partition_groups = []
|
||||
|
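The MoE branch above follows the usual mixed-precision pattern: gradients flow into fp16 working parameters, the step runs on fp32 master copies, and the result is cast back. A standalone sketch of that round trip (toy sizes, SGD stands in for the real optimizer):

```python
import torch

working = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
master = working.detach().clone().to(torch.float32)            # fp32 master copy

working.grad = torch.randn(4, dtype=torch.float16)             # produced by backward
master.grad = working.grad.to(master.dtype).to(master.device)  # move grad to master
working.grad = None

torch.optim.SGD([master], lr=0.1).step()                       # update happens in fp32

master.grad = None
working.data = master.data.to(working.device).to(working.dtype).detach()
```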
@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
|
||||
else:
|
||||
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
|
||||
if hasattr(self, "master_moe_params"):
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
master_moe_param.copy_(working_moe_param)
|
||||
|
||||
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self._param_store.working_to_master_param
|
||||
|
||||
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
|
||||
if hasattr(self, "moe_master_to_working_map"):
|
||||
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
|
||||
return self._param_store.master_to_working_param
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
<a href="https://www.colossalai.org/"> 文档 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> 例程 </a> |
|
||||
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
|
||||
<a href="https://medium.com/@hpcaitech"> 博客 </a></h3>
|
||||
<a href="https://hpc-ai.com/blog"> 博客 </a></h3>
|
||||
|
||||
[![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers)
|
||||
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml)
|
||||
|
|
|
@ -23,7 +23,7 @@ pip install colossalai
|
|||
If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime.
|
||||
|
||||
```shell
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
BUILD_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ cd ColossalAI
|
|||
pip install -r requirements/requirements.txt
|
||||
|
||||
# install colossalai
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify `BUILD_EXT`:
|
||||
|
@ -61,7 +61,7 @@ unzip 1.8.0.zip
|
|||
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
|
||||
# install
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<!-- doc-test-command: echo "installation.md does not need test" -->
|
||||
|
|
|
@ -19,10 +19,8 @@
|
|||
当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为
|
||||
$$
|
||||
\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]
|
||||
```
|
||||
这就是所谓的行并行方式.
|
||||
$$
|
||||
|
||||
这就是所谓的行并行方式.
|
||||
为了计算
|
||||
$$
|
||||
Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]
|
||||
|
|
|
@ -20,10 +20,10 @@ pip install colossalai
|
|||
|
||||
**注:现在只支持Linux。**
|
||||
|
||||
如果你想同时安装PyTorch扩展的话,可以添加`CUDA_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。
|
||||
如果你想同时安装PyTorch扩展的话,可以添加`BUILD_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。
|
||||
|
||||
```shell
|
||||
CUDA_EXT=1 pip install colossalai
|
||||
BUILD_EXT=1 pip install colossalai
|
||||
```
|
||||
|
||||
## 从源安装
|
||||
|
@ -38,10 +38,10 @@ cd ColossalAI
|
|||
pip install -r requirements/requirements.txt
|
||||
|
||||
# install colossalai
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`:
|
||||
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`BUILD_EXT=1`:
|
||||
|
||||
```shell
|
||||
pip install .
|
||||
|
@ -60,7 +60,7 @@ unzip 1.8.0.zip
|
|||
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
|
||||
# install
|
||||
CUDA_EXT=1 pip install .
|
||||
BUILD_EXT=1 pip install .
|
||||
```
|
||||
|
||||
<!-- doc-test-command: echo "installation.md does not need test" -->
|
||||
|
|
|
@ -1,84 +0,0 @@
|
|||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
SUPPORT_XFORMERS = False
|
||||
SUPPORT_FLASH2 = False
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
|
||||
SUPPORT_XFORMERS = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
SUPPORT_FLASH2 = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2
|
||||
|
||||
|
||||
def llama_flash_attention(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if SUPPORT_FLASH2:
|
||||
attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
|
||||
else:
|
||||
attn_output = xops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def replace_xformers(model: nn.Module):
|
||||
for module in model.modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(llama_flash_attention, module)
|
|
@ -0,0 +1 @@
|
|||
../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
|
|
@ -3,7 +3,7 @@ import resource
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from attn import SUPPORT_FLASH, replace_xformers
|
||||
from attn import replace_with_flash_attention
|
||||
from data_utils import RandomDataset
|
||||
from model_utils import format_numel_str, get_model_numel
|
||||
from performance_evaluator import PerformanceEvaluator
|
||||
|
@ -188,8 +188,7 @@ def main():
|
|||
model.gradient_checkpointing_enable()
|
||||
|
||||
if args.xformers:
|
||||
assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
|
||||
replace_xformers(model)
|
||||
replace_with_flash_attention(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Optional, Tuple
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import SUPPORT_XFORMERS, replace_xformers
|
||||
from attn import replace_with_flash_attention
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
|
@ -219,8 +219,7 @@ def main():
|
|||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
|
||||
replace_xformers(model)
|
||||
replace_with_flash_attention(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Optional, Tuple
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from attn import SUPPORT_XFORMERS, replace_xformers
|
||||
from attn import replace_with_flash_attention
|
||||
from data_utils import load_json, prepare_dataloader, save_json
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Optimizer
|
||||
|
@ -238,8 +238,7 @@ def main():
|
|||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
if args.flash_attention:
|
||||
assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
|
||||
replace_xformers(model)
|
||||
replace_with_flash_attention(model)
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
|
|
@ -126,7 +126,7 @@ class _CppExtension(_Extension):
|
|||
def load(self):
|
||||
try:
|
||||
op_kernel = self.import_op()
|
||||
except ImportError:
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
# if import error occurs, it means that the kernel is not pre-built
|
||||
# so we build it jit
|
||||
op_kernel = self.build_jit()
|
||||
|
|
24
setup.py
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
@ -15,7 +14,6 @@ except ImportError:
|
|||
|
||||
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1
|
||||
IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1
|
||||
|
||||
# we do not support windows currently
|
||||
if sys.platform == "win32":
|
||||
|
@ -96,23 +94,15 @@ if BUILD_EXT:
|
|||
else:
|
||||
ext_modules = []
|
||||
|
||||
# always put not nightly branch as the if branch
|
||||
# otherwise github will treat colossalai-nightly as the project name
|
||||
# and it will mess up with the dependency graph insights
|
||||
if not IS_NIGHTLY:
|
||||
version = get_version()
|
||||
package_name = "colossalai"
|
||||
else:
|
||||
# use date as the nightly version
|
||||
version = datetime.today().strftime("%Y.%m.%d")
|
||||
package_name = "colossalai-nightly"
|
||||
version = get_version()
|
||||
package_name = "colossalai"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version=version,
|
||||
packages=find_packages(
|
||||
exclude=(
|
||||
"op_builder",
|
||||
"extensions",
|
||||
"benchmark",
|
||||
"docker",
|
||||
"tests",
|
||||
|
@ -121,8 +111,9 @@ setup(
|
|||
"tests",
|
||||
"scripts",
|
||||
"requirements",
|
||||
"extensions",
|
||||
"*.egg-info",
|
||||
)
|
||||
),
|
||||
),
|
||||
description="An integrated large-scale model training system with efficient parallelization techniques",
|
||||
long_description=fetch_readme(),
|
||||
|
@ -153,10 +144,7 @@ setup(
|
|||
],
|
||||
package_data={
|
||||
"colossalai": [
|
||||
"_C/*.pyi",
|
||||
"kernel/cuda_native/csrc/*",
|
||||
"kernel/cuda_native/csrc/kernel/*",
|
||||
"kernel/cuda_native/csrc/kernels/include/*",
|
||||
"kernel/extensions/csrc/**/*",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
|
|
@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize(
    "test_args",
    [
        {
            "batch_size": 8,
            "num_steps": 4,
            "tp": 2,
            "pp": 2,
            "pp_style": "1f1b",
            "num_model_chunks": 1,
            "num_microbatches": 4,
            "zero": 1,
            "precision": "fp16",
            "initial_scale": 1,
            "max_length": 512,
            "gradient_accumulation_step": 2,
        },
        {
            "batch_size": 8,
            "num_steps": 4,

@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()

@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha

    booster.backward(loss, optimizer)
    optimizer.step()
    for group in optimizer.param_groups:
        group["lr"] = 0.1

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"

@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
    check_state_dict_equal(
        optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
    )
    for group in new_optimizer.param_groups:
        assert group["lr"] == 0.1

    # Check the new model/optimizer can successfully run.
    data = data_gen_fn()

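# Editor's illustration (plain PyTorch, not the Booster API): per-group hyperparameters such as
# "lr" live in optimizer.state_dict()["param_groups"] and must survive a save/load round trip,
# which is what the assertions added above check for the Gemini checkpoint path.
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
for group in optim.param_groups:
    group["lr"] = 0.1                      # change a hyperparameter after construction

state = optim.state_dict()                 # the new value is captured in the checkpointed state

new_optim = torch.optim.Adam(model.parameters(), lr=0.001)
new_optim.load_state_dict(state)           # restore
assert all(g["lr"] == 0.1 for g in new_optim.param_groups)
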
@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
    optimizer.backward(loss)

    optimizer.step()

    for group in optimizer.param_groups:
        group["lr"] = 0.1
    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"

@@ -10,6 +10,7 @@ from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from colossalai.booster.plugin import TorchFSDPPlugin
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.testing import rerun_if_address_is_in_use, spawn

@@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt():
        outputs_sec = fsdp_model(inputs)
        assert criterion(outputs_sec) == criterion(outputs)

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optim_ckpt_path = f"{tempdir}/optimizer"

        run_model()

        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)

        full_msd = fsdp_model.unwrap().state_dict()
        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        import copy

        sharded_osd = copy.deepcopy(full_osd)

        run_model()

        full_msd_updated = fsdp_model.unwrap().state_dict()
        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        # cost much time led to timeout
        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
        # assert not compare_nested_dict(full_msd_updated, full_msd)
        outputs_first = fsdp_model(inputs)
        assert criterion(outputs_first) != criterion(outputs)

        booster.load_model(fsdp_model, model_ckpt_path)
        booster.load_optimizer(optimizer, optim_ckpt_path)

        full_msd_restore = fsdp_model.unwrap().state_dict()
        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
        assert compare_nested_dict(full_msd_restore, full_msd)
        outputs_sec = fsdp_model(inputs)
        assert criterion(outputs_sec) == criterion(outputs)


def run_dist(rank, world_size, port):
    # init dist env

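# Editor's illustration (not part of this diff): the round-trip pattern the FSDP test above
# follows, written with plain torch.save/torch.load instead of the Booster sharded-checkpoint
# and FSDP full_optim_state_dict APIs.
import copy
import torch

model = torch.nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

reference = copy.deepcopy(model.state_dict())        # snapshot taken at save time
torch.save(model.state_dict(), "/tmp/model.pt")      # assumed scratch path

model(torch.randn(2, 8)).sum().backward()            # train a bit more so live weights drift
optim.step()

model.load_state_dict(torch.load("/tmp/model.pt"))   # restore the checkpoint
for key, value in model.state_dict().items():
    assert torch.equal(value, reference[key])        # restored weights match the snapshot
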
@@ -1,13 +1,22 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size


def delete_moe_info(model):
    for _, param in model.named_parameters():
        if hasattr(param, "moe_info"):
            delattr(param, "moe_info")


class MoeModel(nn.Module):

@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
    for i in range(world_size - 1):
        a = tensor_list[i]
        b = tensor_list[i + 1]
        assert not torch.allclose(a, b), \
            (f"expected tensors on rank {i} and {i + 1} not to be equal "
             f"but they are, {a} vs {b}")
        assert not torch.allclose(a, b), (
            f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
        )


def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()

    if isinstance(model, LowLevelZeroModel):
        optimizer.backward(loss)
    else:
        loss.backward()
    return y


def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
    """Sync the parameters of tp model from ep model

    Args:
        local_model (MoeModule)
        ep_model (MoeModule)
    """
    for (local_name, local_param), (ep_name, ep_param) in zip(
        local_model.named_parameters(), ep_model.named_parameters()
    ):
        assert local_name in ep_name, print(f"{local_name} != {ep_name}")
        if "experts" not in local_name:
            if assert_grad_flag:
                assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
                assert torch.allclose(local_param.grad, ep_param.grad)
            else:
                local_param.data.copy_(ep_param.data)
            continue

        # gather param from ep model
        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
        all_param = torch.cat(param_list, dim=0)
        if assert_grad_flag:
            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
            all_grad = torch.cat(grad_list, dim=0)

        if assert_grad_flag:
            assert torch.allclose(local_param, all_param)
            assert torch.allclose(local_param.grad, all_grad)
        else:
            local_param.data.copy_(all_param.data)


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    assert_close(a, b, rtol=rtol, atol=atol)

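# Editor's illustration (not part of this diff): how the loose_close helper added above is meant
# to be used. bf16 arithmetic accumulates more rounding error than fp32, so the comparison picks
# rtol/atol from the dtype under test rather than torch's tight defaults.
import torch
from torch.testing import assert_close

x = torch.randn(4)
y = x + 1e-3                                # small drift, e.g. between two equivalent training paths
assert_close(x, y, rtol=4e-3, atol=4e-3)    # passes under the bf16 tolerances loose_close applies
# assert_close(x, y)                        # the default fp32 tolerances would reject the same drift
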
@@ -12,7 +12,6 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn

sys.path.append(

@@ -95,6 +94,7 @@ def get_model(parallel):
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=1,
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )

@@ -103,6 +103,7 @@ def get_model(parallel):
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=dist.get_world_size(),
            zero_stage=2,
            custom_policy=OpenMoeForCausalLMPolicy(),
        )

@@ -111,6 +112,7 @@ def get_model(parallel):
            precision="bf16",
            tp_size=1,
            pp_size=1,
            ep_size=2,
            zero_stage=2,
            extra_dp_size=2,
            custom_policy=OpenMoeForCausalLMPolicy(),

@@ -120,6 +122,7 @@ def get_model(parallel):
            precision="bf16",
            tp_size=1,
            pp_size=2,
            ep_size=2,
            zero_stage=1,
            microbatch_size=1,
            custom_policy=OpenMoeForCausalLMPolicy(),

@@ -130,27 +133,6 @@ def get_model(parallel):


def _test_moe_checkpoint(rank, parallel):
    if parallel == None:
        MOE_MANAGER.setup(
            parallel=None,
        )
    elif parallel == "ep":
        MOE_MANAGER.setup(
            parallel="EP",
        )
    elif parallel == "ep_zero":
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=2,
        )
    elif parallel == "hybrid":
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=1,
            fixed_ep_size=2,
            fixed_pp_size=2,
        )
    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    model3, booster3, optim3 = get_model(parallel)

@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
    _test_moe_checkpoint(rank, parallel)


@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])

@@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


@pytest.mark.parametrize(["router", "num_groups"], [
    (Top1Router(), 1),
    (Top2Router(), 1),
    # (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
    (4, 5, 8),
    (3, 4, 4),
])
@pytest.mark.parametrize(
    ["router", "num_groups"],
    [
        (Top1Router(), 1),
        (Top2Router(), 1),
        # (TopKRouter(num_selected_experts=3), 4),
    ],
)
@pytest.mark.parametrize(
    ["batch_size", "seq_len", "num_experts"],
    [
        (4, 5, 8),
        (3, 4, 4),
    ],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
    x = torch.randn((batch_size * seq_len, num_experts)).cuda()
    if num_groups > 1:

@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex

    router.train()
    if isinstance(router, TopKRouter):
        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        _, combine_array, dispatch_mask = router(x)
        combine_array, dispatch_mask = router(x)[1:3]
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

    router.eval()
    if isinstance(router, TopKRouter):
        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        _, combine_array, dispatch_mask = router(x)
        combine_array, dispatch_mask = router(x)[1:3]
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

@@ -4,102 +4,75 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep


def split_ddp_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()

    if isinstance(model, LowLevelZeroModel):
        optimizer.backward(loss)
    else:
        loss.backward()
    return y


def run_zero_test(local_rank, world_size, stage=1):
def run_zero_test(local_rank, stage=1):
    criterion = torch.nn.CrossEntropyLoss()

    zero_model = MoeModel()
    optimizer = torch.optim.Adam(zero_model.parameters())
    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
    booster = Booster(plugin=plugin)
    zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="EP")
    moe_model = MoeModel().bfloat16()
    moe_optimizer = torch.optim.Adam(moe_model.parameters())
    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    moe_booster = Booster(plugin=moe_plugin)
    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)

    torch_model = MoeModel()
    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
        torch_param.data.copy_(zero_param.data)
    torch_model = torch_model.cuda()
    grad_handler = MoeGradientHandler(torch_model)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel=None)
    zero_model = MoeModel().bfloat16()
    delete_moe_info(zero_model)
    zero_optimizer = torch.optim.Adam(zero_model.parameters())
    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    zero_booster = Booster(plugin=zero_plugin)
    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
    sync_local_from_ep(zero_model, moe_model)

    # assert zero model
    for (torch_name, torch_param), (zero_name, zero_param) in zip(
        torch_model.named_parameters(), zero_model.module.named_parameters()
    ):
        assert zero_name == torch_name
        assert torch.allclose(zero_param.data, torch_param.data)

    data = torch.randn(16, 4).cuda()
    data = torch.randn(16, 4).bfloat16().cuda()
    label = torch.randint(0, 4, (16,)).cuda()

    torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
    zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
    assert torch.allclose(torch_out, zero_out)
    grad_handler.handle_gradient()
    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
    moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
    assert torch.allclose(zero_out, moe_out)

    for (zero_name, zero_param), (torch_name, torch_param) in zip(
        zero_model.module.named_parameters(), torch_model.named_parameters()
    for (moe_name, moe_param), (zero_name, zero_param) in zip(
        moe_model.module.named_parameters(), zero_model.module.named_parameters()
    ):
        assert zero_name == torch_name
        zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
        if hasattr(zero_param, "moe_info"):
            assert len(zero_grad_list) == 0
            assert torch.allclose(zero_param.grad, torch_param.grad)
        assert moe_name == zero_name
        moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
        zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
        if hasattr(moe_param, "moe_info"):
            assert len(moe_grad_list) == 0
            if stage == 1:
                zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
            else:
                zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
            assert torch.allclose(
                moe_param.grad, zero_grad, atol=1e-5
            ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
        else:
            assert len(zero_grad_list) > 0
            torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
            if stage == 2:
                torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
            assert len(zero_grad_list) == len(torch_grad_list)
            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
                assert torch.allclose(zero_grad, torch_grad)
            assert len(moe_grad_list) > 0
            assert len(moe_grad_list) == len(zero_grad_list)
            for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
                assert torch.allclose(moe_grad, zero_grad)


def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    MOE_MANAGER.setup(parallel="EP")
    seed_all(42 + rank)
    run_zero_test(rank, world_size, stage=1)
    run_zero_test(rank, world_size, stage=2)
    run_zero_test(rank, stage=stage)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
    spawn(run_dist, world_size)
def test_moe_zero_model(world_size, stage):
    spawn(run_dist, world_size, stage=stage)


if __name__ == "__main__":
    test_moe_zero_model(world_size=2)
    test_moe_zero_model(world_size=2, stage=1)

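# Editor's illustration (not part of this diff): why the stage-1 branch above can index the
# partitioned gradient list by local_rank. A ZeRO optimizer flattens and pads a gradient, then
# splits it evenly so each rank owns one shard; rank r's view is simply the r-th chunk.
import torch

world_size, rank = 2, 1                              # assumed toy values
grad = torch.arange(6, dtype=torch.float32)          # pretend flattened gradient
padding = (world_size - grad.numel() % world_size) % world_size
grad = torch.nn.functional.pad(grad, [0, padding])
shards = grad.split(grad.numel() // world_size)
print(shards[rank])                                  # tensor([3., 4., 5.]) -- the shard rank 1 owns
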
@@ -4,89 +4,80 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep


def split_ddp_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()

    if isinstance(model, LowLevelZeroModel):
        optimizer.backward(loss)
    else:
        loss.backward()
    return y


def run_zero_optim_test(local_rank, world_size, stage=1):
def run_zero_test(local_rank, stage=1):
    criterion = torch.nn.CrossEntropyLoss()

    zero_model = MoeModel()
    zero_optimizer = torch.optim.Adam(zero_model.parameters())
    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
    booster = Booster(plugin=plugin)
    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="EP")
    moe_model = MoeModel().bfloat16()
    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    moe_booster = Booster(plugin=moe_plugin)
    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)

    torch_model = MoeModel()
    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
        torch_param.data.copy_(zero_param.data)
    torch_optimizer = torch.optim.Adam(torch_model.parameters())
    torch_model = torch_model.cuda()
    grad_handler = MoeGradientHandler(torch_model)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel=None)
    zero_model = MoeModel().bfloat16()
    delete_moe_info(zero_model)
    sync_local_from_ep(zero_model, moe_model)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
    zero_booster = Booster(plugin=zero_plugin)
    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)

    for _ in range(2):
        data = torch.randn(16, 4).cuda() / (local_rank + 1)
        label = torch.randint(0, 4, (16,)).cuda()
        run_fwd_bwd(torch_model, data, label, criterion, None)
        run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
        grad_handler.handle_gradient()
    for (moe_name, moe_param), (zero_name, zero_param) in zip(
        moe_model.named_parameters(), zero_model.named_parameters()
    ):
        if ".experts." in moe_name:
            continue
        assert moe_name == zero_name
        assert torch.allclose(
            moe_param.data, zero_param.data
        ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"

        torch_optimizer.step()
    for _ in range(1):
        data = torch.randn(2, 4).bfloat16().cuda()
        label = torch.randint(0, 4, (2,)).cuda()

        moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
        zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
        assert torch.allclose(zero_out, moe_out)
        moe_optimizer.step()
        zero_optimizer.step()

        for (torch_name, torch_param), (zero_name, zero_param) in zip(
            torch_model.named_parameters(), zero_model.named_parameters()
        for (moe_name, moe_param), (zero_name, zero_param) in zip(
            moe_model.named_parameters(), zero_model.named_parameters()
        ):
            assert torch.allclose(
                torch_param.data, zero_param.data
            ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
            assert moe_name == zero_name
            if is_moe_tensor(moe_param):
                param_size = moe_param.shape[0]
                zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
            loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)

        torch_optimizer.zero_grad()
        moe_optimizer.zero_grad()
        zero_optimizer.zero_grad()


def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, stage):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    MOE_MANAGER.setup(parallel="EP")
    run_zero_optim_test(rank, world_size, stage=1)
    run_zero_optim_test(rank, world_size, stage=2)
    seed_all(42 + rank)
    run_zero_test(rank, stage=stage)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
    spawn(run_dist, world_size)
def test_moe_zero_optim(world_size, stage):
    spawn(run_dist, world_size, stage=stage)


if __name__ == "__main__":
    test_moe_zero_optim(world_size=2)
    test_moe_zero_optim(world_size=2, stage=1)

@@ -0,0 +1,20 @@
import torch.nn as nn
from torch.optim import Adam

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


def test_lr_scheduler_save_load():
    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-3)
    scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
    new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
    for _ in range(5):
        scheduler.step()
    state_dict = scheduler.state_dict()
    new_scheduler.load_state_dict(state_dict)
    assert state_dict == new_scheduler.state_dict()


if __name__ == "__main__":
    test_lr_scheduler_save_load()

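# Editor's illustration (not part of this diff): the same save/load round-trip check written
# against a stock PyTorch scheduler, for readers without ColossalAI installed.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=5)
new_scheduler = CosineAnnealingLR(optimizer, T_max=5)
for _ in range(5):
    optimizer.step()      # step the optimizer first to avoid PyTorch's ordering warning
    scheduler.step()
new_scheduler.load_state_dict(scheduler.state_dict())
assert scheduler.state_dict() == new_scheduler.state_dict()
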
@@ -1 +1 @@
0.3.4
0.3.5