diff --git a/.github/workflows/release_nightly_on_schedule.yml b/.github/workflows/release_nightly_on_schedule.yml
index 4125f333f..072a943ae 100644
--- a/.github/workflows/release_nightly_on_schedule.yml
+++ b/.github/workflows/release_nightly_on_schedule.yml
@@ -6,11 +6,13 @@ on:
- cron: '0 0 * * 6' # release on every Sunday 00:00 UTC time
jobs:
- build-n-publish:
+ publish:
if: github.repository == 'hpcaitech/ColossalAI'
name: Build and publish Python 🐍 distributions 📦 to PyPI
runs-on: ubuntu-latest
timeout-minutes: 20
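+    # expose the publish step's outcome so the notify job can report the release status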
+ outputs:
+ status: ${{ steps.publish.outcome }}
steps:
- uses: actions/checkout@v2
@@ -18,7 +20,9 @@ jobs:
with:
python-version: '3.8.14'
- - run: NIGHTLY=1 python setup.py sdist build
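+      # rewrite setup.py with the nightly package name and date-based version before building the sdist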
+ - run: |
+ python .github/workflows/scripts/update_setup_for_nightly.py
+ python setup.py sdist build
# publish to PyPI if executed on the main branch
- name: Publish package to PyPI
@@ -31,7 +35,7 @@ jobs:
notify:
name: Notify Lark via webhook
- needs: build-n-publish
+ needs: publish
runs-on: ubuntu-latest
if: ${{ always() }} && github.repository == 'hpcaitech/ColossalAI'
steps:
@@ -62,4 +66,4 @@ jobs:
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
- STATUS: ${{ steps.publish.outcome }}
+ STATUS: ${{ needs.publish.outputs.status }}
diff --git a/.github/workflows/release_test_pypi_before_merge.yml b/.github/workflows/release_test_pypi_before_merge.yml
index 284ab4d1a..7af641fc3 100644
--- a/.github/workflows/release_test_pypi_before_merge.yml
+++ b/.github/workflows/release_test_pypi_before_merge.yml
@@ -49,6 +49,6 @@ jobs:
# we need to install the requirements.txt first
# as test-pypi may not contain the distributions for libs listed in the txt file
pip install -r requirements/requirements.txt
- pip install --index-url https://test.pypi.org/simple/ colossalai==$VERSION
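+          # fall back to the official PyPI index for dependencies that are not published on test-pypi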
+ pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.python.org/pypi colossalai==$VERSION
env:
VERSION: ${{ steps.prep-version.outputs.version }}
diff --git a/.github/workflows/scripts/update_setup_for_nightly.py b/.github/workflows/scripts/update_setup_for_nightly.py
new file mode 100644
index 000000000..d8a3087ef
--- /dev/null
+++ b/.github/workflows/scripts/update_setup_for_nightly.py
@@ -0,0 +1,34 @@
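+# Rewrite setup.py in place so the nightly build is published as "colossalai-nightly" with a date-based version.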
+from datetime import datetime
+
+
+def open_setup_file():
+ with open("setup.py", "r") as f:
+ file_lines = f.readlines()
+ return file_lines
+
+
+def replace_nightly_package_info(file_lines):
+ version = datetime.today().strftime("%Y.%m.%d")
+ package_name = "colossalai-nightly"
+
+ for idx, line in enumerate(file_lines):
+ if "version = get_version()" in line:
+ file_lines[idx] = f'version = "{version}"\n'
+ if 'package_name = "colossalai"' in line:
+ file_lines[idx] = f'package_name = "{package_name}"\n'
+ return file_lines
+
+
+def write_setup_file(file_lines):
+ with open("setup.py", "w") as f:
+ f.writelines(file_lines)
+
+
+def main():
+ file_lines = open_setup_file()
+ file_lines = replace_nightly_package_info(file_lines)
+ write_setup_file(file_lines)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/README.md b/README.md
index 13757eece..442e6bbcd 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
Documentation |
Examples |
Forum |
- Blog
+ Blog
[](https://github.com/hpcaitech/ColossalAI/stargazers)
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml)
@@ -398,10 +398,10 @@ pip install colossalai
**Note: only Linux is supported for now.**
-However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
+However, if you want to build the PyTorch extensions during installation, you can set `BUILD_EXT=1`.
```bash
-CUDA_EXT=1 pip install colossalai
+BUILD_EXT=1 pip install colossalai
```
**Otherwise, CUDA kernels will be built during runtime when you actually need them.**
@@ -429,7 +429,7 @@ By default, we do not compile CUDA/C++ kernels. ColossalAI will build them durin
If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
```shell
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory.
@@ -445,7 +445,7 @@ unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
# install
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
(back to top)
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index c0e257f54..e67e16231 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -49,12 +49,13 @@ def _preprocess(
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
- sequences = [s + t for s, t in zip(sources, targets)]
+ sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)]
sequences_token = tokenizer(
- sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
+
sources_token = tokenizer(
- sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
@@ -65,7 +66,8 @@ def _preprocess(
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
- labels[i][-pad_len:] = IGNORE_INDEX
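+                # guard against pad_len == 0: labels[i][-0:] would select the whole row and mask every label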
+                if pad_len > 0:
+                    labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][: pad_len + source_len] = IGNORE_INDEX
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index 0fb4da3d3..b7d176847 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
- --max_epochs 1
+ --max_epochs 1
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
index a2cfb2ef6..327651f4e 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
@@ -1,20 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-import numpy as np
import os
-import random
from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
-from transformers.tokenization_utils import PreTrainedTokenizer
import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
@@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
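+    # "max_length" pads every batch up to max_length; any other value keeps the batch at its longest sequence length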
+ padding: str = "max_length"
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
@@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
- # pad to max
- to_pad = self.max_length - input_ids.size(1)
- input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
- labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+ if self.padding == "max_length":
+ # pad to max
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
@@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
-
-
-def setup_distributed_dataloader(
- dataset: DatasetType,
- batch_size: int = 1,
- shuffle: bool = False,
- seed: int = 1024,
- drop_last: bool = False,
- pin_memory: bool = False,
- num_workers: int = 0,
- collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
- process_group: Optional[ProcessGroup] = None,
- **kwargs,
-) -> DataLoader:
- """
- Setup dataloader for distributed training.
- """
- _kwargs = kwargs.copy()
- process_group = process_group or _get_default_group()
- sampler = StatefulDistributedSampler(
- dataset=dataset,
- num_replicas=process_group.size(),
- rank=process_group.rank(),
- shuffle=shuffle,
- seed=seed,
- drop_last=drop_last,
- )
-
- # Deterministic dataloader
- def seed_worker(worker_id: int) -> None:
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- sampler=sampler,
- num_workers=num_workers,
- collate_fn=collate_fn,
- pin_memory=pin_memory,
- drop_last=drop_last,
- worker_init_fn=seed_worker,
- **_kwargs,
- )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
index 1926ec78a..6c048c3b1 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -1,15 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import math
from types import MethodType
from typing import Optional, Tuple
import torch
+import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
-from flash_attn.ops.rms_norm import rms_norm
+from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
@@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
repeat_kv,
)
+from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
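+# flash-attn kernels are only available on CUDA devices; on NPU the patch falls back to torch_npu fused ops below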
+if get_accelerator().name == "cuda":
+ from flash_attn.bert_padding import pad_input, unpad_input
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+ from flash_attn.ops.rms_norm import rms_norm
-def _prepare_decoder_attention_mask(
- self: LlamaModel,
- attention_mask: torch.BoolTensor,
- input_shape: torch.Size,
- inputs_embeds: torch.Tensor,
- past_key_values_length: int,
-) -> Optional[torch.Tensor]:
- """
- Decoder attetion mask
- """
- if past_key_values_length > 0 and attention_mask is not None:
- attention_mask = torch.cat(
- tensors=(
- torch.full(
- size=(input_shape[0], past_key_values_length),
- fill_value=True,
- dtype=attention_mask.dtype,
- device=attention_mask.device,
+ def _prepare_decoder_attention_mask(
+ self: LlamaModel,
+ attention_mask: torch.BoolTensor,
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ) -> Optional[torch.Tensor]:
+ """
+        Decoder attention mask
+ """
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ tensors=(
+ torch.full(
+ size=(input_shape[0], past_key_values_length),
+ fill_value=True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
),
- attention_mask,
- ),
- dim=-1,
- ) # (bsz, past_key_values_length + q_len)
- if attention_mask is not None and torch.all(attention_mask):
- return None # Faster
- return attention_mask
-
-
-def attention_forward(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """
- Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
- """
- if output_attentions:
- logger.warning(
- "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
- "return `None` instead."
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- q_slicing, kv_slicing = (
- dim // self.config.pretraining_tp
- for dim in (
- self.num_heads * self.head_dim,
- self.num_key_value_heads * self.head_dim,
- )
- ) # `Tuple[int, int]`
- q_slices, k_slices, v_slices = (
- proj.weight.split(slicing, dim=0)
- for proj, slicing in (
- (self.q_proj, q_slicing),
- (self.k_proj, kv_slicing),
- (self.v_proj, kv_slicing),
- )
- ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
- q, k, v = (
- torch.cat(
- [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
dim=-1,
+ ) # (bsz, past_key_values_length + q_len)
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # Faster
+ return attention_mask
+
+ def attention_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+ """
+ if output_attentions:
+ logger.warning(
+ "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
+ "return `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ q_slicing, kv_slicing = (
+ dim // self.config.pretraining_tp
+ for dim in (
+ self.num_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ )
+ ) # `Tuple[int, int]`
+ q_slices, k_slices, v_slices = (
+ proj.weight.split(slicing, dim=0)
+ for proj, slicing in (
+ (self.q_proj, q_slicing),
+ (self.k_proj, kv_slicing),
+ (self.v_proj, kv_slicing),
+ )
+ ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+ q, k, v = (
+ torch.cat(
+ [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
+ dim=-1,
+ )
+ for slices in (q_slices, k_slices, v_slices)
+ )
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+ else:
+ q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+
+ # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
+ q, k, v = (
+ states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
+ for states, num_heads in (
+ (q, self.num_heads),
+ (k, self.num_key_value_heads),
+ (v, self.num_key_value_heads),
)
- for slices in (q_slices, k_slices, v_slices)
)
- # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
- # (bsz, q_len, num_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim)
- else:
- q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
- # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
- # (bsz, q_len, num_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim)
+ kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
+ past_kv_len = 0
+ if past_key_value is not None:
+ # if `past_key_value` is not None, `kv_len` > `q_len`.
+ past_kv_len = past_key_value[0].shape[-2]
+ kv_len += past_kv_len
- # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
- # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
- # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
- q, k, v = (
- states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
- for states, num_heads in (
- (q, self.num_heads),
- (k, self.num_key_value_heads),
- (v, self.num_key_value_heads),
- )
- )
- kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
- past_kv_len = 0
- if past_key_value is not None:
- # if `past_key_value` is not None, `kv_len` > `q_len`.
- past_kv_len = past_key_value[0].shape[-2]
- kv_len += past_kv_len
+ # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
+ cos, sin = self.rotary_emb(v, seq_len=kv_len)
+ # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
+ q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ k = torch.cat([past_key_value[0], k], dim=2)
+ v = torch.cat([past_key_value[1], v], dim=2)
- # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
- cos, sin = self.rotary_emb(v, seq_len=kv_len)
- # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
- q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
- if past_key_value is not None:
- # reuse k, v, self_attention
- k = torch.cat([past_key_value[0], k], dim=2)
- v = torch.cat([past_key_value[1], v], dim=2)
+ past_key_value = (k, v) if use_cache else None
- past_key_value = (k, v) if use_cache else None
+ # repeat k/v heads if n_kv_heads < n_heads
+ k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+ v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
- # repeat k/v heads if n_kv_heads < n_heads
- k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
- # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
- v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
- # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+ key_padding_mask = attention_mask
+ # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
+ q, k, v = (states.transpose(1, 2) for states in (q, k, v))
- key_padding_mask = attention_mask
- # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
- q, k, v = (states.transpose(1, 2) for states in (q, k, v))
-
- if past_kv_len > 0:
- q = torch.cat(
- tensors=(
- torch.full(
- size=(bsz, past_kv_len, self.num_heads, self.head_dim),
- fill_value=0.0,
- dtype=q.dtype,
- device=q.device,
+ if past_kv_len > 0:
+ q = torch.cat(
+ tensors=(
+ torch.full(
+ size=(bsz, past_kv_len, self.num_heads, self.head_dim),
+ fill_value=0.0,
+ dtype=q.dtype,
+ device=q.device,
+ ),
+ q,
),
- q,
- ),
- dim=1,
- ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
+ dim=1,
+ ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
- if key_padding_mask is None:
- # (bsz, past_kv_len + q_len, num_heads, head_dim)
- output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
- output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
- else:
- q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
- kv, _, cu_kv_lens, max_kv_len = unpad_input(
- hidden_states=torch.stack(tensors=(k, v), dim=2),
- attention_mask=key_padding_mask,
- )
- output_unpad = flash_attn_varlen_kvpacked_func(
- q=q,
- kv=kv,
- cu_seqlens_q=cu_q_lens,
- cu_seqlens_k=cu_kv_lens,
- max_seqlen_q=max_q_len,
- max_seqlen_k=max_kv_len,
- dropout_p=0.0,
- softmax_scale=None,
- causal=True,
- )
- output = pad_input(
- hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
- indices=indices,
- batch=bsz,
- seqlen=past_kv_len + q_len,
- ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+ if key_padding_mask is None:
+ # (bsz, past_kv_len + q_len, num_heads, head_dim)
+ output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
+ output = rearrange(
+ output, pattern="... h d -> ... (h d)"
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+ else:
+ q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
+ kv, _, cu_kv_lens, max_kv_len = unpad_input(
+ hidden_states=torch.stack(tensors=(k, v), dim=2),
+ attention_mask=key_padding_mask,
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q=q,
+ kv=kv,
+ cu_seqlens_q=cu_q_lens,
+ cu_seqlens_k=cu_kv_lens,
+ max_seqlen_q=max_q_len,
+ max_seqlen_k=max_kv_len,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output = pad_input(
+ hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
+ indices=indices,
+ batch=bsz,
+ seqlen=past_kv_len + q_len,
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
- if past_kv_len > 0:
- # Strip off the zero query outputs.
- output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
- output = self.o_proj(output) # (bsz, q_len, hidden_size)
- return output, None, past_key_value
+ if past_kv_len > 0:
+ # Strip off the zero query outputs.
+ output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
+ output = self.o_proj(output) # (bsz, q_len, hidden_size)
+ return output, None, past_key_value
+ def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+        Forward function for RMS Norm
+ """
+ return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
-def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
- """
- Formard function for RMS Norm
- """
- return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
+ def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.forward = MethodType(attention_forward, module)
+ if isinstance(module, LlamaModel):
+ module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+ if isinstance(module, LlamaRMSNorm):
+ module.forward = MethodType(rms_norm_forward, module)
+elif get_accelerator().name == "npu":
+ import torch_npu
-def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
- for name, module in model.named_modules():
- if isinstance(module, LlamaAttention):
- module.forward = MethodType(attention_forward, module)
- if isinstance(module, LlamaModel):
- module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
- if isinstance(module, LlamaRMSNorm):
- module.forward = MethodType(rms_norm_forward, module)
+ class NPULlamaAttention(LlamaAttention):
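+        # set use_flash = False to fall back to the eager softmax attention path in forward()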
+ use_flash: bool = True
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.setup()
+
+ def setup(self):
+ self._softmax_scale = 1 / math.sqrt(self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if not self.use_flash:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+ else:
+ attn_output, *_ = torch_npu.npu_fusion_attention(
+ query_states,
+ key_states,
+ value_states,
+ self.num_heads,
+ "BNSD",
+ atten_mask=attention_mask.bool(),
+ scale=self._softmax_scale,
+ padding_mask=None,
+ pre_tockens=65535,
+ next_tockens=0,
+ keep_prob=1.0,
+ inner_precise=0,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum(
+ [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
+ )
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ class NPURMSNorm(LlamaRMSNorm):
+ def forward(self, hidden_states):
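+            # npu_rms_norm returns a tuple; index 0 is the normalized output tensor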
+ return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
+
+ def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.__class__ = NPULlamaAttention
+ module.setup()
+ if isinstance(module, LlamaRMSNorm):
+ module.__class__ = NPURMSNorm
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
index 9f6c9c1cc..21d769f3c 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
@@ -17,7 +17,7 @@ import torch
def unwrap(model):
if hasattr(model, "module"):
- return unwrap_model(model.module)
+ return model.unwrap()
else:
return model
diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py
index 7fe2d92ab..63ce91e50 100644
--- a/applications/Colossal-LLaMA-2/inference_example.py
+++ b/applications/Colossal-LLaMA-2/inference_example.py
@@ -1,22 +1,21 @@
import argparse
-import os
import torch
+from colossal_llama2.dataset.conversation import default_conversation
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
from colossalai.logging import get_dist_logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
logger = get_dist_logger()
def load_model(model_path, device="cuda", **kwargs):
- logger.info(
- "Please check whether the tokenizer and model weights are properly stored in the same folder."
- )
+ logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
model.to(device)
try:
- tokenizer = AutoTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
except OSError:
raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.")
@@ -27,31 +26,51 @@ def load_model(model_path, device="cuda", **kwargs):
def generate(args):
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
- BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
- input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
+ if args.prompt_style == "sft":
+ conversation = default_conversation.copy()
+ conversation.append_message("Human", args.input_txt)
+ conversation.append_message("Assistant", None)
+ input_txt = conversation.get_prompt()
+ else:
+ BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
+ input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
- inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
- output = model.generate(**inputs,
- max_new_tokens=args.max_new_tokens,
- do_sample=args.do_sample,
- temperature=args.temperature,
- top_k=args.top_k,
- top_p=args.top_p,
- num_return_sequences=1)
- response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
- logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
+ inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
+ num_input_tokens = inputs["input_ids"].shape[-1]
+ output = model.generate(
+ **inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ num_return_sequences=1,
+ )
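+    # decode only the newly generated tokens, skipping the prompt portion of the output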
+ response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
+ logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
- parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
- parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
- parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
- parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
- parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
- parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
- parser.add_argument('--top_p', type=int, default=0.95, help="Set top_p value for generation")
- parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="hpcai-tech/Colossal-LLaMA-2-7b-base",
+ help="HF repo name or local path of the model",
+ )
+ parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=512,
+        help="Set the maximum number of tokens to generate, ignoring the number of tokens in the prompt",
+ )
+ parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
+ parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
+ parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
+ parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
+ parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
+ parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
args = parser.parse_args()
- generate(args)
\ No newline at end of file
+ generate(args)
diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt
index d8afee768..34afaf7e5 100644
--- a/applications/Colossal-LLaMA-2/requirements.txt
+++ b/applications/Colossal-LLaMA-2/requirements.txt
@@ -1,9 +1,9 @@
torch<2.0.0, >=1.12.1
packaging==23.1
-colossalai==0.3.2
+colossalai==0.3.5
autoflake==2.2.1
black==23.9.1
-transformers
+transformers==4.33.3
tensorboard==2.14.0
six==1.16.0
datasets
diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh
index 276d9ce99..6a1c887bf 100644
--- a/applications/Colossal-LLaMA-2/train.example.sh
+++ b/applications/Colossal-LLaMA-2/train.example.sh
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
--warmup_steps 100 \
--use_grad_checkpoint \
--use_flash_attn \
+ --pad_token "unk"
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 92863e8e4..2e4bab75a 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
"""
import argparse
@@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
- setup_distributed_dataloader,
)
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
+from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
def get_model_numel(model: torch.nn.Module) -> int:
@@ -83,6 +85,7 @@ def main() -> None:
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
@@ -108,6 +111,12 @@ def main() -> None:
default=False,
help="Use flash-attention",
)
+ parser.add_argument(
+ "--use_neft",
+ action="store_true",
+ default=False,
+ help="Use NEFTune",
+ )
parser.add_argument(
"--freeze_non_embeds_params",
action="store_true",
@@ -116,6 +125,8 @@ def main() -> None:
)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--zero", type=int, default=1)
+ parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
+ parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
args = parser.parse_args()
with open(args.config_file, "w") as f:
@@ -125,6 +136,7 @@ def main() -> None:
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
+ accelerator = get_accelerator()
coordinator = DistCoordinator()
# ==============================
@@ -142,6 +154,7 @@ def main() -> None:
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
+ enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -149,6 +162,7 @@ def main() -> None:
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
+ enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -182,7 +196,10 @@ def main() -> None:
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
- tokenizer.pad_token = tokenizer.unk_token
+ if args.pad_token == "eos":
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.pad_token == "unk":
+ tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
@@ -193,38 +210,36 @@ def main() -> None:
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
- dataloader = setup_distributed_dataloader(
+ data_collator = DataCollatorForSupervisedDataset(
+ tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
+ )
+ dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.micro_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
+ distributed_sampler_cls=StatefulDistributedSampler,
)
coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
-
- # colossalai has changed api for get_current_device in 0.3.4 version or newer
- try:
- from colossalai.accelerator import get_accelerator
-
- current_device = get_accelerator().get_current_device()
- except:
- from colossalai.utils import get_current_device
-
- current_device = get_current_device()
-
- init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ init_ctx = (
+ LazyInitContext(default_device=get_current_device())
+ if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+ else nullcontext()
+ )
with init_ctx:
- model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+ model = LlamaForCausalLM.from_pretrained(args.pretrained)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
+    # This is essential; otherwise gradient checkpointing will not work.
+ model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
@@ -246,12 +261,14 @@ def main() -> None:
adamw_mode=True,
)
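+    # default warmup: 2.5% of the total number of optimizer steps (len(dataloader) // accumulation_steps per epoch)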
+ if args.warmup_steps is None:
+ args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
+        coordinator.print_on_master(f"Warmup steps set to {args.warmup_steps}")
+
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
- total_steps=args.num_epochs * len(dataloader),
- warmup_steps=args.warmup_steps
- if args.warmup_steps is not None
- else int(args.num_epochs * len(dataloader) * 0.025),
+ total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
+ warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
@@ -267,11 +284,9 @@ def main() -> None:
torch.set_default_dtype(torch.float)
- if args.load_checkpoint is None:
- coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
- booster.load_model(model, args.pretrained, strict=False)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
@@ -298,85 +313,109 @@ def main() -> None:
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
- num_steps_per_epoch = len(dataloader)
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ num_steps_per_epoch = len(dataloader) // args.accumulation_steps
# If resume training, set the sampler start index to the correct value
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
- with tqdm(
- iterable=enumerate(dataloader, start=start_step),
+ pbar = tqdm(
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
- initial=start_step,
- ) as pbar:
- for step, batch in pbar:
- batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+ initial=start_step // args.accumulation_steps,
+ )
+ total_loss = torch.tensor(0.0, device=get_current_device())
+ for step, batch in enumerate(dataloader, start=start_step):
+ batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
- batch_output = model(**batch)
+ batch_output = model(**batch)
- loss = batch_output.loss
+ loss = batch_output.loss / args.accumulation_steps
+ total_loss.add_(loss.data)
- booster.backward(loss=loss, optimizer=optimizer)
+ booster.backward(loss=loss, optimizer=optimizer)
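+            # step the optimizer and LR scheduler only once every accumulation_steps micro-batches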
+ if (step + 1) % args.accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- all_reduce_mean(tensor=loss)
- pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
+ all_reduce_mean(tensor=total_loss)
+ pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
if coordinator.is_master():
- global_step = epoch * num_steps_per_epoch + step
- writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
+ global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+ writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
- # Save modeling.
+ total_loss.fill_(0.0)
+ pbar.update()
+            # Save model checkpoint.
- if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
- coordinator.print_on_master("\nStart saving model checkpoint with running states")
- save_checkpoint(
- save_dir=args.save_dir,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- epoch=epoch,
- step=step + 1,
- batch_size=args.micro_batch_size,
- coordinator=coordinator,
- )
- coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
- )
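+                # save_interval counts optimizer steps, so scale it by accumulation_steps when comparing against the micro-batch index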
+ if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
+ step + 1
+ ) == len(dataloader):
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")
- # Delete CUDA cache.
- # del batch, batch_labels, batch_output, loss
- torch.cuda.empty_cache()
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune before saving model.")
+ deactivate_neftune(model, handle)
+
+ accelerator.empty_cache()
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.micro_batch_size,
+ coordinator=coordinator,
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )
+
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ # Delete cache.
+ # del batch, batch_labels, batch_output, loss
+ accelerator.empty_cache()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
start_step = 0
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune.")
+ deactivate_neftune(model, handle)
+
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+ coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh
index dcb11515d..d87f9ef82 100755
--- a/applications/Colossal-LLaMA-2/train_sft.example.sh
+++ b/applications/Colossal-LLaMA-2/train_sft.example.sh
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
-colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
--pretrained $PRETRAINED_MODEL_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
--use_grad_checkpoint \
--use_flash_attn \
--use_neft \
+ --pad_token "eos"
diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py
deleted file mode 100644
index fd9e1cd3e..000000000
--- a/applications/Colossal-LLaMA-2/train_sft.py
+++ /dev/null
@@ -1,403 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
-"""
-
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-import torch.distributed as dist
-from colossal_llama2.dataset.loader import (
- DataCollatorForSupervisedDataset,
- StatefulDistributedSampler,
- load_tokenized_dataset,
- setup_distributed_dataloader,
-)
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-
-
-def get_model_numel(model: torch.nn.Module) -> int:
- return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
- B = 1024**3
- M = 1024**2
- K = 1024
- if numel >= B:
- return f"{numel / B:.2f} B"
- elif numel >= M:
- return f"{numel / M:.2f} M"
- elif numel >= K:
- return f"{numel / K:.2f} K"
- else:
- return f"{numel}"
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
- tensor.div_(dist.get_world_size())
- return tensor
-
-
-def main() -> None:
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--pretrained",
- type=str,
- default=None,
- help="Address of the pre-trained modeling",
- )
- parser.add_argument("--dataset", nargs="+", default=[])
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
- parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
- parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
- parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
- parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
- parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
- parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
- parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default="fp16",
- choices=["fp16", "bf16"],
- help="Mixed precision",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument(
- "--use_grad_checkpoint",
- action="store_true",
- default=False,
- help="Use gradient checkpointing",
- )
- parser.add_argument(
- "--use_flash_attn",
- action="store_true",
- default=False,
- help="Use flash-attention",
- )
- parser.add_argument(
- "--use_neft",
- action="store_true",
- default=False,
- help="Use NEFTune",
- )
- parser.add_argument(
- "--freeze_non_embeds_params",
- action="store_true",
- default=False,
- help="Freeze non embeddings parameters",
- )
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--zero", type=int, default=1)
- args = parser.parse_args()
-
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch({})
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if coordinator.is_master():
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=1,
- zero_stage=args.zero,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Tokenizer, Dataset, Collator and Dataloader
- # ======================================================
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
- coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
- coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
-
- coordinator.print_on_master(f"Load dataset: {args.dataset}")
-
- dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
- dataloader = setup_distributed_dataloader(
- dataset=dataset,
- batch_size=args.micro_batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- )
- coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- init_ctx = (
- LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- )
- with init_ctx:
- model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
- # Freeze part of parameters.
- if args.freeze_non_embeds_params:
- freeze_non_embeds_parameters(model=model)
-
- if args.use_grad_checkpoint:
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
- if args.use_flash_attn:
- replace_with_flash_attention(model=model)
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
-
- model_numel = get_model_numel(model)
- coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
-
- optimizer = HybridAdam(
- model_params=filter(lambda p: p.requires_grad, model.parameters())
- if args.freeze_non_embeds_params
- else model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- if args.warmup_steps is None:
- args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optimizer,
- total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- # Flash attention will be disabled because it does NOT support fp32.
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optimizer, _, dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- dataloader=dataloader,
- )
-
- torch.set_default_dtype(torch.float)
-
- if args.load_checkpoint is None:
- coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
- booster.load_model(model, args.pretrained, strict=False)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- start_step = 0
- sampler_start_idx = 0
- if args.load_checkpoint is not None:
- if "modeling" in args.load_checkpoint:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
- booster.load_model(model, args.load_checkpoint)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.load_checkpoint,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- )
- coordinator.print_on_master(
- f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- if args.use_neft:
- coordinator.print_on_master("Activate NEFTune.")
- model, handle = activate_neftune(model)
-
- num_steps_per_epoch = len(dataloader) // args.accumulation_steps
- # If resume training, set the sampler start index to the correct value
- assert isinstance(dataloader.sampler, StatefulDistributedSampler)
- dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- for epoch in range(start_epoch, args.num_epochs):
- dataloader.sampler.set_epoch(epoch=epoch)
- pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
- total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
- for step, batch in enumerate(dataloader):
- batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
-
- batch_output = model(**batch)
-
- loss = batch_output.loss / args.accumulation_steps
- total_loss += loss.item()
-
- booster.backward(loss=loss, optimizer=optimizer)
-
- if (step + 1) % args.accumulation_steps == 0:
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- all_reduce_mean(tensor=total_loss)
- pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
- if coordinator.is_master():
- global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
- writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
- writer.add_scalar(
- tag="Learning Rate",
- scalar_value=lr_scheduler.get_last_lr()[0],
- global_step=global_step,
- )
- total_loss.fill_(0.0)
- pbar.update()
- # Save modeling.
-
- if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
- step + 1
- ) == len(dataloader):
- coordinator.print_on_master("\nStart saving model checkpoint with running states")
-
- if args.use_neft:
- coordinator.print_on_master("Deactivate NEFTune before saving model.")
- deactivate_neftune(model, handle)
-
- save_checkpoint(
- save_dir=args.save_dir,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- epoch=epoch,
- step=step + 1,
- batch_size=args.micro_batch_size,
- coordinator=coordinator,
- )
- coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
- )
-
- if args.use_neft:
- coordinator.print_on_master("Activate NEFTune.")
- model, handle = activate_neftune(model)
-
- # Delete CUDA cache.
- # del batch, batch_labels, batch_output, loss
- torch.cuda.empty_cache()
-
- # the continue epochs are not resumed, so we need to reset the sampler start index and start step
- dataloader.sampler.set_start_index(start_index=0)
- start_step = 0
-
- if args.use_neft:
- coordinator.print_on_master("Deactivate NEFTune.")
- deactivate_neftune(model, handle)
-
- # Final save.
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- main()
diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py
index f293c4f69..9c70c0d2a 100644
--- a/applications/ColossalEval/colossal_eval/models/chatglm.py
+++ b/applications/ColossalEval/colossal_eval/models/chatglm.py
@@ -3,6 +3,8 @@ from typing import List
import torch
+from colossalai.utils import get_current_device
+
from .huggingface import HuggingFaceModel
IGNORE_INDEX = -100
@@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
"""
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
- torch.cuda.current_device()
+ get_current_device()
)
outputs = self.model(input_ids)[0]
@@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
truncation=True,
return_tensors="pt",
max_length=self.model_max_length - max_new_tokens,
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
# Set output_scores=True to get prediction scores.
outputs = self.model.generate(
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
index 741c884f0..fff697e21 100644
--- a/applications/ColossalEval/colossal_eval/models/huggingface.py
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
from colossalai.logging import DistributedLogger
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils import get_current_device
from .base import BaseModel
@@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
self.model = AutoModel.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
- self.model.to(torch.cuda.current_device())
+ self.model.to(get_current_device())
if peft_path is not None:
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
else:
- self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
if peft_path is not None:
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
self.model.eval()
@@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
"""
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
- torch.cuda.current_device()
+ get_current_device()
)
- attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
@@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
return_tensors="pt",
return_token_type_ids=False,
max_length=self.model_max_length - max_new_tokens,
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
# Set output_scores=True to get prediction scores.
outputs = self.model.generate(
@@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
- self.model.to(torch.cuda.current_device())
+ self.model.to(get_current_device())
if peft_path is not None:
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
else:
- self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
if peft_path is not None:
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
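The hunks above replace hard-coded `torch.cuda.current_device()` calls with ColossalAI's device helper so the evaluation code no longer assumes a CUDA backend. A minimal sketch of the resulting pattern, assuming ColossalAI is installed (the padding value `0` stands in for the tokenizer's real `pad_token_id`):

```python
import torch
from colossalai.utils import get_current_device

device = get_current_device()  # resolves to the active accelerator device instead of hard-coding CUDA

input_ids_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
input_ids = torch.nn.utils.rnn.pad_sequence(
    input_ids_list, batch_first=True, padding_value=0
).to(device)
attention_mask = input_ids.ne(0).to(device)
```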
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
index 5b09f9de8..a340f3bfd 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/inference.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -8,6 +8,7 @@ import torch.distributed as dist
from colossal_eval import dataset, models, utils
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig
@@ -82,6 +83,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
+ accelerator = get_accelerator()
world_size = dist.get_world_size()
rank = dist.get_rank()
@@ -235,10 +237,10 @@ def main(args):
),
)
- logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+ logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
del model_
- torch.cuda.empty_cache()
+ accelerator.empty_cache()
dist.barrier()
if rank == 0:
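The inference script now goes through the accelerator abstraction rather than calling `torch.cuda` directly. A small sketch of the two calls touched above, assuming an accelerator is available at runtime:

```python
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
peak_gb = accelerator.max_memory_allocated() / 1024**3  # was torch.cuda.max_memory_allocated()
print(f"peak device mem: {peak_gb:.3f} GB")
accelerator.empty_cache()                               # was torch.cuda.empty_cache()
```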
diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md
new file mode 100644
index 000000000..be50a8f9f
Binary files /dev/null and b/applications/ColossalMoE/README.md differ
diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
new file mode 100644
index 000000000..d08dfd5f8
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -0,0 +1,629 @@
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.checkpoint_io import CheckpointIndexFile
+from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
+from colossalai.checkpoint_io.utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ load_shard_state_dict,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict_shards,
+ search_tp_partition_dim,
+ sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.moe import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+
+try:
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+
+class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True,
+ ) -> None:
+ super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
+ moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
+ self.ep_group = moe_info.ep_group
+ self.ep_size = moe_info.ep_size
+ self.ep_rank = moe_info.ep_rank
+ self.real_dp_rank = moe_info.dp_rank
+
+ @staticmethod
+ def _model_sharder(
+ model: nn.Module,
+ prefix: str = "",
+ keep_vars: bool = False,
+ size_per_shard: int = 1024,
+ param_name_pattern: Optional[str] = None,
+ ) -> Iterator[Tuple[OrderedDict, int]]:
+ # An internal method that breaks the model state_dict into shards within a limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ # Save parameters.
+ for name, param in model.named_parameters():
+ if param is None:
+ continue
+ if param_name_pattern is not None and param_name_pattern not in name:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Save buffers.
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in model._non_persistent_buffers_set:
+ buffer = buf if keep_vars else buf.detach()
+ block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+ if block is not None:
+ yield block, block_size
+
+ # Save extra states.
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
+ extra_state = model.get_extra_state()
+ block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+ - Multiple files that store state tensors of models.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_model.-000XX.bin"
+
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+ prefix (str, optional): Prefix of file to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ if self.real_dp_rank != 0:
+ dist.barrier()
+ return
+
+ # ep_rank 0 saves all the parameters and buffers.
+ # other ep_ranks save only experts
+ ep_param_pattern = "experts." if self.ep_rank != 0 else None
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
+ model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
+ )
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.tp_rank == 0
+
+ if self.pp_size == 1 and self.ep_size == 1:
+ # When pipeline is not used, save the model shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ dist.barrier()
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+ weights_name = weights_name.replace(
+ ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
+ )
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ dist.barrier()
+ return
+
+ dist.barrier()
+
+ # The global master rank integrates the index files and cleans the folder.
+ if self.coordinator.is_master():
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for weight, weight_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(weight, weight_filename)
+
+ final_index_file.write_index_file(final_index_file_path)
+ save_config_file(model, checkpoint)
+ rmtree(tmp_index_file_folder)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ @staticmethod
+ def gather_from_sharded_optimizer_state(
+ state: OrderedDict,
+ param: torch.Tensor,
+ original_shape: torch.Size,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ use_zero: bool,
+ inplace: bool,
+ is_moe_param: bool,
+ device: torch.device = torch.device("cpu"),
+ ) -> OrderedDict:
+ """
+ With given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+ Args:
+ state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+ param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+ original_shape (torch.Size): The size of parameter before sharding.
+ dp_group (ProcessGroup): The process group of data parallel.
+ tp_group (ProcessGroup): The process group of tensor parallel.
+ use_zero (bool): Whether Zero is used.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
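+ is_moe_param (bool): Whether the parameter is a MoE parameter; the Zero gathering step is skipped for MoE parameters.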
+ device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
+
+ Returns:
+ OrderedDict: The complete optimizer state of given parameter.
+ """
+ dp_size = dist.get_world_size(dp_group)
+ tp_size = dist.get_world_size(tp_group)
+ current_shape = param.shape
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # First gather Zero shards.
+ if use_zero and not is_moe_param:
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+ dist.all_gather(gather_tensor, v, group=dp_group)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+
+ # Then gather TP shards.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+ if partition_dim is not None:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.all_gather(gather_tensor, v, group=tp_group)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ @staticmethod
+ def _optimizer_sharder(
+ optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ size_per_shard: int = 1024,
+ only_moe_param: bool = False,
+ ):
+ # An internal method that breaks the optimizer state_dict into shards within a limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+ master_to_working_map = optimizer.get_master_to_working_map()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info["param2id"][id(working_param)]
+ original_shape = param_info["param2shape"][id(working_param)]
+ state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False,
+ is_moe_param=is_moe_tensor(working_param),
+ )
+
+ if only_moe_param and not is_moe_tensor(working_param):
+ continue
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files that store state tensors of optimizers.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_optim.-000XX.bin"
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+ prefix (str): Prefix of file to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of states when zero is not used.
+ # In this case, only let the device with dp_rank == 0 save the states.
+ if not self.use_zero and self.real_dp_rank != 0:
+ dist.barrier()
+ return
+
+ # Then collect the sharded states along dp_group (if using zero) / tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
+ optimizer,
+ use_zero=self.use_zero,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ size_per_shard=size_per_shard,
+ only_moe_param=self.ep_rank != 0,
+ )
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
+
+ if self.pp_size == 1 and self.ep_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ dist.barrier()
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
+
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ dist.barrier()
+ return
+
+ dist.barrier()
+
+ # The global master rank integrates the index files and cleans the folder.
+ if self.coordinator.is_master():
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+ # IDs are obtained from the param2id mapping saved earlier in optimizer.param_info.
+ id_map = {}
+ master_to_working_map = optimizer.get_master_to_working_map()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
+ updated_groups.append(new_pg)
+ # ep param groups
+ if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+ # If the file holding this param's states has already been loaded, skip it.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True,
+ is_moe_param=is_moe_tensor(working_param),
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def shard_from_complete_optimizer_state(
+ self,
+ state: OrderedDict,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ is_moe_param: bool,
+ ) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+ slice out the sharded optimizer states kept by current device.
+
+ Args:
+ state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+ current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
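+ is_moe_param (bool): Whether the parameter is a MoE parameter; Zero sharding is skipped for MoE parameters.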
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # Shard state along tensor parallel group.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+ if partition_dim is not None:
+ slice_size = current_shape[partition_dim]
+ v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+ # Shard state along data parallel group when using Zero.
+ if self.use_zero and not is_moe_param:
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ raise NotImplementedError
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ raise NotImplementedError
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
+ raise NotImplementedError
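For context on the expert-parallel saving scheme implemented above: `ep_rank` 0 saves the full set of parameters and buffers, while the other expert-parallel ranks save only the expert weights they hold, selected by the `"experts."` pattern passed to `_model_sharder`. A toy illustration of that filter with hypothetical parameter names (not part of the patch):

```python
# Hypothetical parameter names; only the filtering rule mirrors the patch.
param_names = [
    "model.embed_tokens.weight",
    "model.layers.0.block_sparse_moe.gate.weight",
    "model.layers.0.block_sparse_moe.experts.0.w1.weight",
    "model.layers.0.block_sparse_moe.experts.1.w1.weight",
]


def params_to_save(names, ep_rank):
    pattern = "experts." if ep_rank != 0 else None  # same rule as ep_param_pattern above
    return [n for n in names if pattern is None or pattern in n]


print(params_to_save(param_names, ep_rank=0))  # ep_rank 0 keeps all four entries
print(params_to_save(param_names, ep_rank=1))  # other ranks keep only the expert weights
```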
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
new file mode 100644
index 000000000..a2b78a2bd
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
@@ -0,0 +1,92 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import MOE_MANAGER
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
+
+
+class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+ def __init__(self, config):
+ super().__init__(config)
+ self.setup_ep()
+
+ def setup_ep(self):
+ _, moe_info = MOE_MANAGER.get_info(self.num_experts)
+ ep_group = moe_info.ep_group
+ self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+ self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ assert self.num_experts % self.ep_size == 0
+ self.ep_group = ep_group
+ self.num_experts_per_ep = self.num_experts // self.ep_size
+ self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
+ held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+ set_tensors_to_none(self.experts, exclude=set(held_experts))
+ for p in self.experts.parameters():
+ set_moe_tensor_info(p, moe_info)
+
+ @staticmethod
+ def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+ LazyInitContext.materialize(module)
+ module.__class__ = EPMixtralSparseMoeBlock
+ module.setup_ep()
+ return module
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ selected_experts = selected_experts.t().reshape(-1)
+ selected_experts_idx = selected_experts.argsort()
+ dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
+ input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+ output_split_sizes = torch.zeros_like(input_split_sizes)
+ dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+
+ input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+ # compute expert output
+ output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ if output_states.size(0) > 0:
+ if self.num_experts_per_ep == 1:
+ # no need to split
+ expert = self.experts[self.expert_start_idx]
+ output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
+ output_states = expert.w2(output_states)
+ else:
+ output_states_splits = output_states.split(output_split_sizes.tolist())
+ output_states_list = []
+ for i, split_states in enumerate(output_states_splits):
+ if split_states.size(0) == 0:
+ continue
+ expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
+ split_states = expert.w2(split_states)
+ output_states_list.append(split_states)
+ output_states = torch.cat(output_states_list)
+ output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+ dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+ recover_experts_idx = torch.empty_like(selected_experts_idx)
+ recover_experts_idx[selected_experts_idx] = torch.arange(
+ selected_experts_idx.size(0), device=selected_experts_idx.device
+ )
+ dispatch_states = dispatch_states[recover_experts_idx]
+ k_hidden_states = dispatch_states.chunk(self.top_k)
+ output_states = k_hidden_states[0] * routing_weights[:, 0, None]
+ for i in range(1, self.top_k):
+ output_states += k_hidden_states[i] * routing_weights[:, i, None]
+ output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
+ return output_states, router_logits
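The forward pass above groups tokens by their routed expert id before the uneven all-to-all and restores the original order afterwards. The permutation bookkeeping can be checked in isolation; the sketch below runs in a single process with a made-up routing result and stand-in tokens:

```python
import torch

top_k, num_experts, num_tokens = 2, 4, 3
# Made-up routing result: one expert id per (k, token) slot, flattened as in the forward above.
selected_experts = torch.tensor([[1, 3, 0], [2, 0, 3]]).reshape(-1)

sort_idx = selected_experts.argsort()                                 # dispatch order, grouped by expert
tokens_per_expert = selected_experts.bincount(minlength=num_experts)  # split sizes for the all-to-all

tokens = torch.arange(top_k * num_tokens)  # stand-in for the repeated hidden states
dispatched = tokens[sort_idx]              # order in which the experts receive tokens

recover_idx = torch.empty_like(sort_idx)
recover_idx[sort_idx] = torch.arange(sort_idx.size(0))
assert torch.equal(dispatched[recover_idx], tokens)  # the original token order is recovered
```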
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
new file mode 100644
index 000000000..218b05b27
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
@@ -0,0 +1,557 @@
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, Module
+from transformers.models.mixtral.modeling_mixtral import (
+ MixtralDecoderLayer,
+ MixtralForCausalLM,
+ MixtralModel,
+ MoeCausalLMOutputWithPast,
+ _prepare_4d_causal_attention_mask,
+ load_balancing_loss_func,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from colossalai.shardformer.shard import ShardConfig
+
+from .mixtral_layer import EPMixtralSparseMoeBlock
+
+__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
+
+
+class MixtralPolicy(Policy):
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ raise NotImplementedError(
+ "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+ )
+
+ if self.shard_config.enable_tensor_parallelism:
+ raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
+
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=FusedRMSNorm,
+ ),
+ policy=policy,
+ target_key=MixtralModel,
+ )
+
+ if self.shard_config.enable_flash_attention:
+ raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=model_cls
+ )
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class MixtralModelPolicy(MixtralPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MixtralModel,
+ new_forward=MixtralPipelineForwards.mixtral_model_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama model"""
+ return []
+
+
+class MixtralForCausalLMPolicy(MixtralPolicy):
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for causal lm
+ new_item = {
+ MixtralForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MixtralForCausalLM,
+ new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ llama_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if (
+ id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ and self.pipeline_stage_manager.num_stages > 1
+ ):
+ # tie weights
+ return [
+ {
+ 0: llama_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
+
+
+class MixtralPipelineForwards:
+ """
+ This class serves as a micro library for forward function substitution of Mixtral models
+ under pipeline setting.
+ """
+
+ @staticmethod
+ def mixtral_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ past_router_logits: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ # TODO(jianghai): recording of kv-value tensors is left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # Embed positions. For the first stage, hidden_states is the input embeddings;
+ # for the other stages, it is the output of the previous stage.
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ output_attentions,
+ output_router_logits,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+
+ if output_router_logits and past_router_logits is not None:
+ all_router_logits = past_router_logits + all_router_logits
+ if stage_manager.is_last_stage():
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+ # always return a dict for intermediate stages
+ return {
+ "hidden_states": hidden_states,
+ "past_router_logits": all_router_logits,
+ }
+
+ @staticmethod
+ def mixtral_for_causal_lm_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ past_router_logits: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): recording of kv-cache tensors is left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = MixtralPipelineForwards.mixtral_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ past_router_logits=past_router_logits,
+ )
+ past_key_values = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=None,
+ hidden_states=outputs[0],
+ attentions=None,
+ router_logits=outputs[-1],
+ )
+ else:
+ out = {}
+ hidden_states = outputs.get("hidden_states")
+ out["hidden_states"] = hidden_states
+ if output_router_logits:
+ out["past_router_logits"] = outputs["past_router_logits"]
+ return out
diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py
new file mode 100644
index 000000000..a2a0a7e78
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/utils.py
@@ -0,0 +1,84 @@
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def move_to_cuda(batch, device):
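+ """Move every tensor in the batch dict onto the given device."""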
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+ save_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+) -> None:
+ """
+ Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+ os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+ load_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+ """
+ Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ # Update booster params states.
+ booster.load_model(model, os.path.join(load_dir, "modeling"))
+ booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+ booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+ running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ running_states["sample_start_index"],
+ )
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
new file mode 100644
index 000000000..46ff70ff3
--- /dev/null
+++ b/applications/ColossalMoE/infer.py
@@ -0,0 +1,111 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from transformers import AutoTokenizer
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="mistralai/Mixtral-8x7B-v0.1",
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="ep",
+ choices=["ep"],
+ help="Parallel methos.",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="bf16",
+ choices=["fp32", "bf16", "fp16"],
+ help="The mixed precision training.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
+ )
+ parser.add_argument(
+ "--use_layernorm_kernel",
+ action="store_true",
+ help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+
+ config = MixtralConfig.from_pretrained(args.model_name)
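+ # expert parallel size cannot exceed the number of experts in the model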
+ ep_size = min(dist.get_world_size(), config.num_local_experts)
+ # Set plugin
+ if args.plugin == "ep":
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=1,
+ ep_size=ep_size,
+ zero_stage=1,
+ precision=args.precision,
+ custom_policy=MixtralForCausalLMPolicy(),
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ enable_fused_normalization=args.use_layernorm_kernel,
+ enable_jit_fused=args.use_kernel,
+ )
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
+
+ # Build mixtral model
+ model = MixtralForCausalLM.from_pretrained(args.model_name)
+ coordinator.print_on_master(f"Finish load model")
+
+ # Prepare tokenizer and dataloader
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+ # Set booster
+ booster = Booster(plugin=plugin)
+ model, _, _, _, _ = booster.boost(model=model)
+ coordinator.print_on_master(f"Finish init booster")
+
+ model.eval()
+
+ if coordinator.rank == 0:
+ text = ["Hello my name is"]
+ else:
+ text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
+ tokenizer.pad_token = tokenizer.unk_token
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
+
+ with torch.no_grad():
+ outputs = model.module.generate(**inputs, max_new_tokens=20)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ print(f"[{coordinator.rank}] {outputs}")
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh
new file mode 100644
index 000000000..0487fe9c1
--- /dev/null
+++ b/applications/ColossalMoE/infer.sh
@@ -0,0 +1,7 @@
+NUM_GPU=2
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+
+# ep
+torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
+ --model_name $MODEL \
+ --plugin "ep" \
diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt
new file mode 100644
index 000000000..9a5738c41
--- /dev/null
+++ b/applications/ColossalMoE/requirements.txt
@@ -0,0 +1,5 @@
+colossalai >= 0.3.3
+torch >= 1.8.1
+transformers == 4.36.0
+sentencepiece
+datasets
diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py
new file mode 100644
index 000000000..275f59e10
--- /dev/null
+++ b/applications/ColossalMoE/setup.py
@@ -0,0 +1,43 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+def fetch_version():
+ with open("version.txt", "r") as f:
+ return f.read().strip()
+
+
+setup(
+ name="colossal_moe",
+ version=fetch_version(),
+ packages=find_packages(
+ exclude=(
+ "tests",
+ "benchmarks",
+ "*.egg-info",
+ )
+ ),
+ description="Colossal-AI MoE",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
+ ],
+)
diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py
new file mode 100644
index 000000000..57589ab20
--- /dev/null
+++ b/applications/ColossalMoE/tests/test_mixtral_layer.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
+from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+import colossalai
+from colossalai.moe import MOE_MANAGER
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_mixtral_moe_layer():
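+ """Compare the EP MoE block against the native Mixtral block in both forward and backward passes."""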
+ torch.cuda.set_device(dist.get_rank())
+ MOE_MANAGER.setup(
+ parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
+ )
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ )
+ torch.manual_seed(0)
+ orig_model = MixtralSparseMoeBlock(config).cuda()
+ x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
+ orig_output, orig_logits = orig_model(x)
+ model = deepcopy(orig_model)
+ model = EPMixtralSparseMoeBlock.from_native_module(model)
+ ep_output, ep_logits = model(x)
+ assert_close(orig_logits, ep_logits)
+ assert_close(orig_output, ep_output)
+ orig_loss = orig_output.mean()
+ orig_loss.backward()
+ ep_loss = ep_output.mean()
+ ep_loss.backward()
+ assert_close(orig_loss, ep_loss)
+ name_to_p = {n: p for n, p in orig_model.named_parameters()}
+ for n, ep_p in model.named_parameters():
+ p = name_to_p[n]
+ if ep_p.grad is not None:
+ assert_close(p.grad, ep_p.grad)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch({}, rank, world_size, "localhost", port)
+ check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mixtral_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mixtral_moe_layer(2)
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
new file mode 100644
index 000000000..822e7410f
--- /dev/null
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -0,0 +1,146 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from torch.optim import Adam
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_model_equal(model1, model2):
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+ for p1, p2 in zip(model1.parameters(), model2.parameters()):
+ assert torch.equal(p1.half(), p2.half())
+
+
+def get_optimizer_snapshot(optim):
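+ """Deep-copy optimizer state and param groups, keyed by parameter id, for later comparison."""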
+ state = {id(k): deepcopy(v) for k, v in optim.state.items()}
+ param_groups = []
+ for group in optim.param_groups:
+ params = [id(p) for p in group["params"]]
+ new_group = {"params": params}
+ for k, v in group.items():
+ if k != "params":
+ new_group[k] = v
+ param_groups.append(new_group)
+ return {
+ "state": state,
+ "param_groups": param_groups,
+ }
+
+
+def check_optimizer_snapshot_equal(snapshot1, snapshot2):
+ # check param_groups
+ assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
+ for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+ assert set(group1.keys()) == set(group2.keys())
+ for k in group1.keys():
+ assert group1[k] == group2[k]
+ # check state
+ assert set(snapshot1["state"].keys()) == set(
+ snapshot2["state"].keys()
+ ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+ for pid in snapshot1["state"].keys():
+ state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+ assert set(state1.keys()) == set(state2.keys())
+ for k in state1.keys():
+ if isinstance(state1[k], torch.Tensor):
+ assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
+ else:
+ assert state1[k] == state2[k]
+
+
+def check_mixtral_moe_layer():
+ torch.cuda.set_device(dist.get_rank())
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ )
+ torch.manual_seed(0)
+ input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+ orig_model = MixtralForCausalLM(config).cuda()
+ model = deepcopy(orig_model)
+ optimizer = Adam(model.parameters(), lr=1e-3)
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=2,
+ ep_size=2,
+ custom_policy=MixtralForCausalLMPolicy(),
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ microbatch_size=1,
+ zero_stage=1,
+ )
+ booster = Booster(plugin=plugin)
+ model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+ # initialize grads
+ data_iter = iter(
+ [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
+ )
+ booster.execute_pipeline(
+ data_iter,
+ model,
+ lambda outputs, inputs: outputs.loss,
+ optimizer,
+ )
+
+ # check save model
+ booster.save_model(model, "mixtral_model", shard=True)
+ dist.barrier()
+ if dist.get_rank() == 0:
+ saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+ check_model_equal(orig_model, saved_model)
+ saved_model.save_pretrained("mixtral_hf_model")
+ dist.barrier()
+
+ # check load model
+ new_model = MixtralForCausalLM(config).cuda()
+ new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+ new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+ booster.load_model(new_model, "mixtral_hf_model")
+ check_model_equal(model, new_model)
+
+ # check save optimizer
+ optimizer.step()
+ for group in optimizer.param_groups:
+ group["lr"] = 0.1
+ snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+ dist.barrier()
+ # reset optimizer state
+ for state in optimizer.unwrap().state.values():
+ for v in state.values():
+ if isinstance(v, torch.Tensor):
+ v.zero_()
+ booster.load_optimizer(optimizer, "mixtral_optim")
+ loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch({}, rank, world_size, "localhost", port)
+ check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [4])
+def test_mixtral_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mixtral_moe_layer(4)
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
new file mode 100644
index 000000000..c567038ec
--- /dev/null
+++ b/applications/ColossalMoE/train.py
@@ -0,0 +1,295 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+from transformers.models.mixtral import MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+@torch.no_grad()
+def get_global_loss(loss, booster):
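+ """Average the loss across the data parallel group."""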
+ global_loss = loss.clone().detach()
+ dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
+ global_loss.div_(booster.plugin.dp_size)
+ return global_loss
+
+
+class RandomDataset(Dataset):
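+ """Synthetic dataset of random token ids used for the finetuning demo."""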
+ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
+ self.num_samples = num_samples
+ self.max_length = max_length
+ self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+ self.attention_mask = torch.ones_like(self.input_ids)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.input_ids[idx],
+ "attention_mask": self.attention_mask[idx],
+ "labels": self.input_ids[idx],
+ }
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="mistralai/Mixtral-8x7B-v0.1",
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="hybrid",
+ choices=["hybrid"],
+ help="Parallel methods.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./outputs",
+ help="The path of your saved model after finetuning.",
+ )
+ parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="Batch size (per dp group) for the training dataloader.",
+ )
+ parser.add_argument(
+ "--save_interval",
+ type=int,
+ default=1000,
+ help=" The interval (steps) of saving checkpoints.",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="bf16",
+ choices=["fp32", "bf16", "fp16"],
+ help="The mixed precision training.",
+ )
+ parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+
+ # optim
+ parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+
+ # lr scheduler
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+
+ # zero stage for all plugins
+ parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
+ # hybrid plugin
+ parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
+ parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
+ parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
+ parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
+
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
+ )
+ parser.add_argument(
+ "--use_layernorm_kernel",
+ action="store_true",
+ help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
+ )
+
+ # load balance
+ parser.add_argument(
+ "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
+ )
+ parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
+ # communicate overlap
+ parser.add_argument(
+ "--comm_overlap",
+ action="store_true",
+ help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
+ )
+ # hierarchical all-to-all
+ parser.add_argument(
+ "--hierarchical_alltoall",
+ action="store_true",
+ help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+
+ # Set plugin
+ if args.plugin == "hybrid":
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=args.pp_size,
+ ep_size=args.ep_size,
+ microbatch_size=args.microbatch_size,
+ custom_policy=MixtralForCausalLMPolicy(),
+ enable_fused_normalization=args.use_layernorm_kernel,
+ enable_jit_fused=args.use_kernel,
+ precision=args.precision,
+ zero_stage=args.zero_stage,
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ )
+
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
+
+ # Build Mixtral model
+ model = MixtralForCausalLM.from_pretrained(args.model_name)
+ coordinator.print_on_master(f"Finish init model")
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Prepare tokenizer and dataloader
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+ dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
+ collate_fn = None
+ dataloader = plugin.prepare_dataloader(
+ dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Set optimizer
+ optimizer = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # Set lr scheduler
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optimizer,
+ total_steps=args.num_epochs * len(dataloader),
+ warmup_steps=args.warmup_steps
+ if args.warmup_steps is not None
+ else int(args.num_epochs * len(dataloader) * 0.025),
+ eta_min=0.1 * args.lr,
+ )
+
+ # Set booster
+ booster = Booster(plugin=plugin)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dataloader=dataloader,
+ )
+ use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ coordinator.print_on_master(f"Finish init booster")
+
+ # Load ckpt
+ if args.load_checkpoint is not None:
+ load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
+ coordinator.print_on_master(f"Finish load optimizer")
+
+ # Start finetuning
+ coordinator.print_on_master(f"Start finetuning")
+ for epoch in range(args.num_epoch):
+ model.train()
+ train_dataloader_iter = iter(dataloader)
+ total_len = len(train_dataloader_iter)
+ with tqdm(
+ range(total_len),
+ desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
+ disable=not is_pp_last_stage if use_pipeline else not coordinator.is_master(),
+ ) as pbar:
+ for step in pbar:
+ if use_pipeline:
+ # Forward pass
+ outputs = booster.execute_pipeline(
+ train_dataloader_iter,
+ model,
+ lambda x, y: x.loss,
+ optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs["loss"]
+ global_loss = get_global_loss(loss, booster)
+ if coordinator._local_rank == "0":
+ pbar.set_postfix({"Loss": global_loss.item()})
+ else:
+ # Forward pass
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data, torch.cuda.current_device())
+ outputs = model(**data)
+ loss = outputs["loss"]
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({"loss": loss.item()})
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Apply load balance
+ # if (
+ # args.load_balance
+ # and args.load_balance_interval > 0
+ # and (step + 1) % args.load_balance_interval == 0
+ # ):
+ # coordinator.print_on_master(f"Apply load balance")
+ # apply_load_balance(model, optimizer)
+ # save checkpoint
+ if (step + 1) % args.save_interval == 0:
+ coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+ save_checkpoint(
+ args.output_path,
+ booster,
+ model,
+ optimizer,
+ lr_scheduler,
+ epoch,
+ step,
+ args.batch_size,
+ coordinator,
+ )
+
+ # save checkpoint at the end of each epoch
+ booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
+ coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+
+ # Finish training
+ coordinator.print_on_master(f"Finish training")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh
new file mode 100644
index 000000000..bee7f5c8f
--- /dev/null
+++ b/applications/ColossalMoE/train.sh
@@ -0,0 +1,19 @@
+NUM_GPU=8
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+SEQ_LENGTH=2048
+BATCH_SIZE=1
+LR=0.00001
+
+# hybrid
+# torchrun --standalone --nproc_per_node $NUM_GPU \
+colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
+ train.py \
+ --num_epoch 1 \
+ --model_name $MODEL \
+ --plugin "hybrid" \
+ --batch_size $BATCH_SIZE \
+ --lr $LR \
+ --zero_stage 1 \
+ --pp_size 2 \
+ --dp_size 1 \
+ --ep_size 8 \
diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt
new file mode 100644
index 000000000..3eefcb9dd
--- /dev/null
+++ b/applications/ColossalMoE/version.txt
@@ -0,0 +1 @@
+1.0.0
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index 0b7b51a71..7439ad5d3 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -4,9 +4,9 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
+from colossalai.auto_parallel.tensor_shard.constants import BCAST_FUNC_OP
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
__all__ = ["binary_elementwise_meta_info"]
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index d2dd00453..27285f95c 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
self.world_size = dist.get_world_size()
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index d14109dd4..95b96bbfd 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
return ["cuda", "npu"]
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
- sampler = DistributedSampler(
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(
dataset,
num_replicas=zero_world_size * extra_dp_world_size,
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5837156a9..da67e6b41 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,5 +1,6 @@
import ctypes
import random
+import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group,
)
else:
- assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+ if self.dp_size == 1:
+ warnings.warn(
+ "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
+ "If you are not intended to use cpu_offload, please consider set zero_stage=0."
+ )
+
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
@@ -1199,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
return outputs
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1223,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index e976d0aaf..45e5a23c1 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self,
tp_size: int,
pp_size: int,
+ ep_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
+ checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ assert (
+ dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+ self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
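+ # register the fixed EP/DP/PP parallel layout with the global MoE manager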
+ MOE_MANAGER.setup(
+ parallel="EP",
+ mode="fixed",
+ fixed_dp_size=self.real_dp_size,
+ fixed_ep_size=ep_size,
+ fixed_pp_size=pp_size,
+ use_ep_inside=use_ep_inside,
+ )
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ self.ep_size = ep_size
+ self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
+ self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
- self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ if self.checkpoint_io is None:
+ self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ else:
+ self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 2ea7593a5..5445b4a63 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,3 +1,5 @@
+import logging
+import os
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
@@ -25,7 +27,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
+from colossalai.checkpoint_io import CheckpointIO, CheckpointIndexFile, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(
self,
- model: nn.Module,
- checkpoint: str,
- gather_dtensor: bool,
- prefix: Optional[str],
- size_per_shard: int,
- use_safetensors: bool,
+ model: ModelWrapper,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
- raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+ if os.path.isfile(checkpoint_path):
+ logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+ return
+
+ Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
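+ # gather the full (unsharded) state dict on rank 0, offloading it to CPU to save GPU memory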
+ with FSDP.state_dict_type(
+ model.unwrap(),
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ ):
+ state_dict = model.unwrap().state_dict()
+
+ state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+ weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint_path)
+
+ # only the master rank actually writes the shard files
+ total_size = utils.save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=self.coordinator.is_master(),
+ use_safetensors=use_safetensors,
+ )
+
+ # only save the index file on the master rank
+ if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ utils.save_config_file(model.unwrap(), checkpoint_path)
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_sharded_model(
self,
@@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
Load model to checkpoint but only on master process.
"""
- raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+ use_safetensors = False
+ if "safetensors" in checkpoint_index_file.name:
+ use_safetensors = True
+
+ if use_safetensors and not utils.is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+ # read checkpoint index file
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ fsdp_state_dict = {}
+ for shard_file in checkpoint_files:
+ fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+
+ with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
+ model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
@@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
Save optimizer to checkpoint but only on master process.
"""
- raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ with FSDP.state_dict_type(
+ optimizer.unwrap_model().unwrap(),
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ ):
+ fsdp_optim_state = FSDP.full_optim_state_dict(
+ optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
+ )
+
+ if self.coordinator.is_master():
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ utils.save_param_groups(fsdp_optim_state, group_file_path)
+
+ sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
+
+ # Save shards of optimizer states.
+ # only the master rank actually writes the shard files
+ total_size = utils.save_state_dict_shards(
+ sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=self.coordinator.is_master(),
+ use_safetensors=False,
+ )
+
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
"""
Load optimizer to checkpoint but only on master process.
"""
- raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
+
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. "
+ "Looking param group file under current directory."
+ )
+
+ saved_param_groups = torch.load(param_group_path)
+
+ # Load param
+ fsdp_optim_state = {}
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+ for shard_file in checkpoint_files:
+ state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ fsdp_optim_state.update(state_dict_shard)
+
+ fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
+
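+ # convert the full optimizer state dict into the sharded format expected by this FSDP-wrapped model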
+ with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
+ fsdp_state = FSDP.optim_state_dict_to_load(
+ model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
+ )
+ optimizer.load_state_dict(fsdp_state)
+
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
@@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase):
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
def support_no_sync(self) -> bool:
- False
+ return False
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 780117598..712324215 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
- self.load_unsharded_model(model, checkpoint, strict)
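+ # fall back to common weight file names inside the checkpoint directory when no index file is found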
+ path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+ if path.is_file():
+ self.load_unsharded_model(model, str(path), strict)
+ else:
+ path = Path(checkpoint, WEIGHTS_NAME)
+ if path.is_file():
+ self.load_unsharded_model(model, str(path), strict)
+ else:
+ self.load_unsharded_model(model, checkpoint, strict)
return origin_model
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index b7900bc0f..36df30335 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,7 +1,7 @@
import copy
-from functools import reduce
import logging
import os
+from functools import reduce
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -445,7 +446,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
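+ # merge current param group hyperparameters with the recorded param ids before saving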
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
@@ -504,7 +509,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
@@ -713,12 +722,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
- device=torch.device("cuda"),
+ device=get_current_device(),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
- state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
@@ -729,7 +742,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Only the master rank do the saving.
if self.coordinator.is_master():
- state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -838,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
- v = v.cuda()
+ v = v.to(get_current_device())
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 721da69d0..6dd0a5fc3 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
-from .layers import SparseMLP
+from .layers import SparseMLP, apply_load_balance
+from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
+ "MOE_MANAGER",
+ "apply_load_balance",
]
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 34342436f..01c837ee3 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
+
+
+def _all_to_all(
+ inputs: torch.Tensor,
+ input_split_sizes: Optional[List[int]] = None,
+ output_split_sizes: Optional[List[int]] = None,
+ group=None,
+ async_op: bool = False,
+):
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ outputs_shape = list(inputs.shape)
+ if output_split_sizes is not None:
+ outputs_shape[0] = sum(output_split_sizes)
+ outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
+ inputs = inputs.contiguous()
+ outputs = outputs.contiguous()
+ handle = dist.all_to_all_single(
+ outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+ )
+ return outputs, handle
+
+
+class AllToAllUneven(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ inputs,
+ input_split_sizes=None,
+ output_split_sizes=None,
+ group=None,
+ overlap: bool = False,
+ ):
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ ctx.input_split_sizes = input_split_sizes
+ ctx.output_split_sizes = output_split_sizes
+ ctx.group = group
+ return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs):
+ return (
+ _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def all_to_all_uneven(
+ inputs: torch.Tensor,
+ input_split_sizes: Optional[List[int]] = None,
+ output_split_sizes: Optional[List[int]] = None,
+ group=None,
+ overlap: bool = False,
+):
+ return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index a8c50eab6..b37ffabea 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
+ torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}."
)
dist.barrier()
+ torch.cuda.empty_cache()
# ========================================================
# Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
- param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
+ elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+ working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
- # ep extra group
- if MOE_MANAGER.parallel == "EP":
+ # ep param group
+ if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = optimizer.optim.param_groups[-1][
- "params"
- ] # Only keep the parameters kept by current pipeline stage.
- for param in new_pg["params"]:
- param.data = param.data.to(torch.float32)
+ new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for pid, state in list(state_dict.items()):
+ if pid in id_map:
+ param = id_map[pid]
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ elif (
+ hasattr(optimizer, "moe_master_to_working_map")
+ and id(param) in optimizer.moe_master_to_working_map
+ ):
+ working_param = optimizer.moe_master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.pre_load_optim(
+ state,
+ working_param,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device="cpu",
+ inplace=True,
+ )
+ state_dict[pid] = sharded_state
+
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
- # Then shard the loaded optimizer states if using tp/zero.
- for param, state in optimizer.optim.state.items():
- device = param.device
- if master_to_working_map is not None and id(param) in master_to_working_map:
- working_param = master_to_working_map[id(param)]
- else:
- working_param = param
- original_shape = optimizer.param_info["param2shape"][id(working_param)]
- sharded_state = self.pre_load_optim(
- state,
- param,
- current_shape=working_param.shape,
- original_shape=original_shape,
- device=device,
- inplace=True,
- )
- optimizer.optim.state[param] = sharded_state
-
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
+ elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+ working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that stores state tensors
"""
+ torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameters has been saved in the "
f"index located at {final_index_file_path}."
)
+ torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 477b76547..8e6ea3884 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
- self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+ self.wi_gate = nn.Parameter(
+ torch.empty(
+ num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+ )
+ )
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
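The gated expert now doubles the `wi_gate` projection only for SwiGLU, which packs the gate and up projections into a single weight; other gated activations such as `silu` keep the plain `intermediate_size`. A small sketch of the resulting parameter shapes (plain tuples, not the `MLPExperts` module itself):

```python
def gated_expert_weight_shapes(num_experts: int, hidden_size: int, intermediate_size: int, activation: str):
    # wi_gate doubles its last dim only for SwiGLU; wi_up is unchanged.
    gate_dim = intermediate_size * 2 if activation == "swiglu" else intermediate_size
    return (num_experts, hidden_size, gate_dim), (num_experts, hidden_size, intermediate_size)


print(gated_expert_weight_shapes(8, 512, 2048, "swiglu"))  # ((8, 512, 4096), (8, 512, 2048))
print(gated_expert_weight_shapes(8, 512, 2048, "silu"))    # ((8, 512, 2048), (8, 512, 2048))
```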
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index b768fb94a..2ac5b186d 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
+ router_loss: bool = True,
+ router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
+ return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
+ self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
+ self.router_loss = router_loss
+ self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
- fp32_input = tokens.to(torch.float)
- fp32_weight = self.gate_weight.to(torch.float)
- gate_output = F.linear(fp32_input, fp32_weight)
+ gate_logits = F.linear(tokens, self.gate_weight)
+ gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
- inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+ inputs=gate_output,
+ use_kernel=self.enable_kernel,
+ ep_group=self.ep_group,
+ use_loss=self.router_loss,
+ use_norm=self.router_norm,
+ )
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
- expert_output = self._ep_process(
- dispatch_data,
- used_capacity,
- overlap=self.enable_comm_overlap
- )
+ expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
- expert_output = self._tp_process(
- dispatch_data,
- used_capacity,
- overlap=self.enable_comm_overlap
- )
+ expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
- raise NotImplementedError("This kind of communication has not been implemented yet.\n"
- "Please use Experts build function.")
+ raise NotImplementedError(
+ "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+ )
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
- return ans
+
+ if self.return_gate_logits:
+ return ans, gate_logits
+ else:
+ return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
- self,
- dispatch_data: torch.Tensor,
- used_capacity: torch.Tensor,
- overlap: bool = False
+ self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
- expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+ expert_input = HierarchicalAllToAll.apply(
+ dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+ )
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
- expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+ expert_output = HierarchicalAllToAll.apply(
+ expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+ )
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
- assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+ assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
- output[:, :, offset:offset + chunk_size, :] = expert_out.data
+ output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
- expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+ expert_out = Capsule(
+ *AllToAll.apply(_expert_out.data, self.ep_group, True),
+ )
_expert_out = None
# all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
- self,
- dispatch_data: torch.Tensor,
- used_capacity: torch.Tensor,
- overlap: bool = False
+ self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
- assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
- "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+ assert (
+ dispatch_data.shape[0] % NUM_CHUNK == 0
+ ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)
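Two behavioural changes in this file are easy to miss among the formatting-only hunks: the gate projection now runs in the model dtype with only the resulting logits upcast to fp32 for routing, and `return_gate_logits=True` makes `forward` return `(ans, gate_logits)` so callers can build their own auxiliary losses. A tensor-only sketch of the revised gating path (shapes and dtypes are illustrative):

```python
import torch
import torch.nn.functional as F

hidden_size, num_experts = 512, 8
tokens = torch.randn(16, hidden_size, dtype=torch.bfloat16)
gate_weight = torch.randn(num_experts, hidden_size, dtype=torch.bfloat16)

gate_logits = F.linear(tokens, gate_weight)  # stays in the model dtype, may be returned to the caller
gate_output = gate_logits.to(torch.float)    # fp32 copy handed to the router
probs = F.softmax(gate_output, dim=-1)

assert gate_logits.dtype == torch.bfloat16 and probs.dtype == torch.float32
```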
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index f5815d05d..e40674c9b 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
self._z_loss = None
self.use_kernel = use_kernel
- def get_capacity(self, logits_shape):
+ def get_capacity(self, num_tokens, num_experts, ep_group=None):
+ if ep_group is not None:
+ num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
+ dist.all_reduce(num_tokens_tensor, group=ep_group)
+ num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
- capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+ capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
@@ -150,7 +154,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
- def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ use_kernel: bool = False,
+ ep_group: Optional[ProcessGroup] = None,
+ use_loss: bool = False,
+ use_norm: bool = False,
+ ) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -168,7 +179,8 @@ class Top1Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
- capacity = self.get_capacity(inputs.shape)
+ num_tokens = inputs.size(0)
+ capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -207,7 +219,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
- return used_capacity, combine_weights, sec_mask
+ return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter):
@@ -240,7 +252,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)
- def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ use_kernel: bool = False,
+ ep_group: Optional[ProcessGroup] = None,
+ use_norm: bool = False,
+ use_loss: bool = True,
+ ) -> Tuple:
"""
Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,8 +276,13 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1)
+ if use_norm:
+ routing_weights, _ = torch.topk(probs, 2, dim=-1)
+ probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
num_experts = probs.size(-1)
- capacity = self.get_capacity(inputs.shape)
+ num_tokens = inputs.size(0)
+ capacity = self.get_capacity(num_tokens, num_experts, ep_group)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
@@ -270,10 +294,11 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
- expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
- self.set_aux_loss(probs, expert_indices, num_experts)
- self.set_z_loss(inputs)
- self.pop_router_loss()
+ if use_loss:
+ expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+ self.set_aux_loss(probs, expert_indices, num_experts)
+ self.set_z_loss(inputs)
+ self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
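`get_capacity` now takes explicit token and expert counts and, when an `ep_group` is given, all-reduces the token count so every expert-parallel rank computes the same capacity. The per-rank arithmetic, sketched without the collective:

```python
import math


def top_k_capacity(k: int, capacity_factor: float, num_tokens: int, num_experts: int, min_capacity: int = 4) -> int:
    # Single-process version of the new formula; the ep_group branch additionally
    # averages num_tokens across expert-parallel ranks before this point.
    capacity = math.floor(k * capacity_factor * num_tokens / num_experts)
    capacity += capacity % 2  # keep capacity even
    return max(capacity, min_capacity)


print(top_k_capacity(2, 1.25, 1024, 8))  # 320 slots per expert for top-2 routing
```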
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index e25e7dd48..c642f1a44 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
+ elif act == "silu":
+ return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")
diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py
index 9d1d8f01d..e55e82280 100644
--- a/colossalai/nn/lr_scheduler/delayed.py
+++ b/colossalai/nn/lr_scheduler/delayed.py
@@ -6,6 +6,8 @@ if Version(torch.__version__) >= Version("2.0.0"):
else:
from torch.optim.lr_scheduler import _LRScheduler
+from colossalai.logging import get_dist_logger
+
class _enable_get_lr_call:
def __init__(self, o):
@@ -19,7 +21,39 @@ class _enable_get_lr_call:
self.o._get_lr_called_within_step = False
-class DelayerScheduler(_LRScheduler):
+class TwoStageScheduler(_LRScheduler):
+ def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1):
+ self.after_scheduler = after_scheduler
+ self.finished = False
+ super().__init__(optimizer, last_epoch)
+
+ def state_dict(self):
+ state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
+ if isinstance(state_dict["after_scheduler"], _LRScheduler):
+ state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
+ state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
+ del state_dict["after_scheduler"]
+ else:
+ raise NotImplementedError()
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ if "after_scheduler_dict" not in state_dict:
+ logger = get_dist_logger()
+ logger.warning(
+ "after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior."
+ )
+ else:
+ self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"])
+ state_dict = {
+ key: value
+ for key, value in state_dict.items()
+ if key not in ("after_scheduler_type", "after_scheduler_dict")
+ }
+ super().load_state_dict(state_dict)
+
+
+class DelayerScheduler(TwoStageScheduler):
"""Starts with a flat lr schedule until it reaches N epochs then applies
the specific scheduler (For example: ReduceLROnPlateau)
@@ -35,19 +69,7 @@ class DelayerScheduler(_LRScheduler):
if delay_epochs < 0:
raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}")
self.delay_epochs = delay_epochs
- self.after_scheduler = after_scheduler
- self.finished = False
- super().__init__(optimizer, last_epoch)
-
- def state_dict(self):
- state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
- if isinstance(state_dict["after_scheduler"], _LRScheduler):
- state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
- state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
- del state_dict["after_scheduler"]
- else:
- raise NotImplementedError()
- return state_dict
+ super().__init__(optimizer, after_scheduler, last_epoch)
def get_lr(self):
if self.last_epoch >= self.delay_epochs:
@@ -71,7 +93,7 @@ class DelayerScheduler(_LRScheduler):
return super(DelayerScheduler, self).step(epoch)
-class WarmupScheduler(_LRScheduler):
+class WarmupScheduler(TwoStageScheduler):
"""Starts with a linear warmup lr schedule until it reaches N epochs then applies
the specific scheduler (For example: ReduceLROnPlateau).
@@ -85,19 +107,7 @@ class WarmupScheduler(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
self.warmup_epochs = int(warmup_epochs)
- self.after_scheduler = after_scheduler
- self.finished = False
- super().__init__(optimizer, last_epoch)
-
- def state_dict(self):
- state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
- if isinstance(state_dict["after_scheduler"], _LRScheduler):
- state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
- state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
- del state_dict["after_scheduler"]
- else:
- raise NotImplementedError()
- return state_dict
+ super().__init__(optimizer, after_scheduler, last_epoch)
def get_lr(self):
if self.last_epoch >= self.warmup_epochs:
@@ -120,7 +130,7 @@ class WarmupScheduler(_LRScheduler):
return super().step(epoch)
-class WarmupDelayerScheduler(_LRScheduler):
+class WarmupDelayerScheduler(TwoStageScheduler):
"""Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule
until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau).
@@ -140,19 +150,7 @@ class WarmupDelayerScheduler(_LRScheduler):
raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}")
self.warmup_epochs = warmup_epochs
self.delay_epochs = delay_epochs
- self.after_scheduler = after_scheduler
- self.finished = False
- super().__init__(optimizer, last_epoch)
-
- def state_dict(self):
- state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
- if isinstance(state_dict["after_scheduler"], _LRScheduler):
- state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
- state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
- del state_dict["after_scheduler"]
- else:
- raise NotImplementedError()
- return state_dict
+ super().__init__(optimizer, after_scheduler, last_epoch)
def get_lr(self):
if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
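The three schedulers now share `TwoStageScheduler`, which serializes the wrapped scheduler as `after_scheduler_type`/`after_scheduler_dict` instead of storing the object itself. A short sketch of the checkpoint round trip this supports (the optimizer and the wrapped cosine schedule are arbitrary choices for illustration):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from colossalai.nn.lr_scheduler.delayed import WarmupScheduler

model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = WarmupScheduler(optimizer, warmup_epochs=5, after_scheduler=CosineAnnealingLR(optimizer, T_max=100))

state = scheduler.state_dict()
# The wrapped scheduler is stored by type name and state dict, not as an object.
assert "after_scheduler_dict" in state and "after_scheduler" not in state

restored = WarmupScheduler(optimizer, warmup_epochs=5, after_scheduler=CosineAnnealingLR(optimizer, T_max=100))
restored.load_state_dict(state)  # restores both the wrapper and the wrapped scheduler
```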
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index e10a7ed7d..92c709218 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
+from ..layer._operation import _gather
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
+ if not shard_config.parallel_output:
+ logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
+ if not shard_config.parallel_output:
+ logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
+
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index b5c9e66e0..415fc6dd5 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -34,6 +34,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
+ parallel_output: bool = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
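`parallel_output` decides whether the LLaMA forward keeps vocab-sharded logits (as needed by the distributed cross-entropy) or all-gathers them along the last dimension for the caller. A toy, communication-free illustration of what `_gather(logits, -1, tp_group)` restores:

```python
import torch

tp_size, batch, seq, vocab = 4, 2, 8, 32000
shard = vocab // tp_size
local_logits = [torch.randn(batch, seq, shard) for _ in range(tp_size)]  # one vocab shard per TP rank
full_logits = torch.cat(local_logits, dim=-1)                            # what the all-gather produces
assert full_logits.shape == (batch, seq, vocab)
```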
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 5301c87b9..acb9fc4ae 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -7,11 +7,12 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from .colo_tensor import _convert_output
-WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point}
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}
def is_no_hook_op(func) -> bool:
- return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS
+ return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS
def filter_colo_parameters(*args, **kwargs):
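Splitting the old whitelist into `WHITE_LIST_FUNCS` and `NO_HOOK_FUNCS` changes which tensor ops bypass the parameter hooks. A quick self-contained check of the new predicate's behaviour (the sets and function are copied from the hunk above):

```python
import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
NO_HOOK_FUNCS = {torch.Tensor.is_floating_point}


def is_no_hook_op(func) -> bool:
    return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS


assert is_no_hook_op(torch.Tensor.__repr__)           # dunder and not whitelisted: skip hooks
assert not is_no_hook_op(torch.Tensor.__getitem__)    # whitelisted dunder: keep hooks
assert is_no_hook_op(torch.Tensor.is_floating_point)  # explicitly hook-free despite the plain name
assert not is_no_hook_op(torch.Tensor.add)            # ordinary op: keep hooks
```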
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
index ba6c77056..5ac3c2b3a 100644
--- a/colossalai/tensor/moe_tensor/moe_info.py
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
+ self.ep_rank = self.pg.coordinate(self.ep_axis)
+ self.dp_rank = self.pg.coordinate(self.dp_axis)
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 1fe99cd89..40de43c43 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -92,7 +92,10 @@ class ColoParamOpHookManager:
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
ColoParamOpHookManager._trigger_post_forward(params)
- return PostFwdPreBwd.apply(params, arg)
+ # in case the output is a tuple, we have to flatten it
+ grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg)
+ new_grad_args = PostFwdPreBwd.apply(params, *grad_args)
+ return _merge_args(new_grad_args, other_args, grad_flags, spec)
@staticmethod
def has_hook() -> bool:
@@ -113,7 +116,7 @@ class PreFwdPostBwd(torch.autograd.Function):
class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
- def forward(ctx, params, args):
+ def forward(ctx, params, *args):
ctx.params = params
return args
@@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
grad_args.append(arg)
else:
other_args.append(arg)
- assert len(grad_args) > 0
return grad_args, other_args, grad_flags, spec
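`post_op` used to hand the whole output to `PostFwdPreBwd.apply`; it now flattens it first so that only floating-point tensors flow through the autograd function, with everything else merged back afterwards (and the removed `assert len(grad_args) > 0` tolerates outputs with no differentiable tensors at all). The internal `_flatten_grad_args`/`_merge_args` helpers roughly amount to this pytree-based sketch:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

output = (torch.randn(2, requires_grad=True), 3, torch.zeros(2, dtype=torch.long))
leaves, spec = tree_flatten(output)
grad_flags = [isinstance(x, torch.Tensor) and x.is_floating_point() for x in leaves]
grad_args = [x for x, keep in zip(leaves, grad_flags) if keep]       # would go through PostFwdPreBwd.apply
other_args = [x for x, keep in zip(leaves, grad_flags) if not keep]  # bypass the autograd function

# Merge back in the original order, as _merge_args does with the saved spec.
grad_iter, other_iter = iter(grad_args), iter(other_args)
merged = [next(grad_iter) if keep else next(other_iter) for keep in grad_flags]
rebuilt = tree_unflatten(merged, spec)
assert isinstance(rebuilt, tuple) and rebuilt[1] == 3
```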
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 79831cf33..bc6c9d088 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -726,11 +726,13 @@ class GeminiDDP(ModelWrapper):
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
del temp_chunk
- if self.reuse_fp16_chunk:
- for chunk_32 in chunk_list:
- chunk_16 = chunk_32.paired_chunk
- assert chunk_16 is not None
- chunk_16.payload.copy_(chunk_32.payload)
+
+ # sync running weights and master weights
+ if self.master_weights:
+ for loaded_chunk in chunk_list:
+ paired_chunk = loaded_chunk.paired_chunk
+ assert paired_chunk is not None
+ paired_chunk.payload.copy_(loaded_chunk.payload)
for name, buf in persistent_buffers.items():
if buf is not None:
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 98fbb0c50..18367af59 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -621,7 +621,10 @@ class GeminiOptimizer(OptimizerWrapper):
Return the param_groups in Pytorch format when saving to checkpoint.
"""
- param_groups = copy.deepcopy(self.param_groups_backup)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
+ ]
# To be compatible with pytorch checkpointing,
# store extra hyperparameters used by pytorch Adam optimizer.
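Building the saved groups from the live `self.optim.param_groups` (for hyperparameters) zipped with `param_groups_backup` (for the recorded param lists) means runtime changes such as a new learning rate survive a checkpoint round trip, which is what the updated Gemini checkpoint test later in this diff asserts. The merge itself, with toy dictionaries:

```python
live_groups = [{"lr": 0.1, "betas": (0.9, 0.999), "params": ["<live tensor objects>"]}]
backup_groups = [{"lr": 0.001, "params": [0, 1, 2]}]  # param ids recorded at wrap time

param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(live_groups, backup_groups)
]
assert param_groups[0]["lr"] == 0.1 and param_groups[0]["params"] == [0, 1, 2]
```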
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e01c852be..a2433d1b2 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
- moe_params = list()
+ self.working_moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
- moe_params.append(param)
+ self.working_moe_params.append(param)
continue
group_params.append(param)
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
- # if there are moe params, store in additional group in optim
- if len(moe_params) > 0:
+ # if there are moe params, store in additional group in optim
+ if len(self.working_moe_params) > 0:
+ self._sync_master_param = False
param_group = dict()
+ # create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
- param_group["params"] = moe_params
+ self.master_moe_params = []
+ for param in self.working_moe_params:
+ self.master_moe_params.append(param.clone().to(torch.float32).detach())
+ # create mapping from master to working for optimizer io
+ self.moe_master_to_working_map = {}
+ for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+ self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+ # add to optim
+ param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
+ # update param for moe ep
+ # move grad to master param and compute norm
+ if len(self.working_moe_params) > 0:
+ moe_grads = []
+ for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+ if master_moe_param.grad is not None:
+ raise RuntimeError("Moe param should not have grad here")
+ grad = working_moe_param.grad
+ # no need to copy fp32 grad if master_weights is False
+ if self._master_weights:
+ grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
+ master_moe_param.grad = grad
+ working_moe_param.grad = None
+ moe_grads.append(grad)
+ grad_partition_groups.append(grad)
+ norm_group = self._compute_grad_norm(gradients=moe_grads)
+ norm_groups.append(norm_group)
+ self.optim.param_groups[-1]["params"] = self.master_moe_params
+ del moe_grads
+
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
- # TODO: we should store master param for ep
- if len(self.param_groups) > len(self._working_param_groups):
- for param in self.param_groups[-1]["params"]:
- param.data = param.data.to(torch.float32)
- param.grad = param.grad.to(torch.float32)
-
# update the parameters
self.optim.step()
- # release the moe gradm
- if len(self.param_groups) > len(self._working_param_groups):
- for param in self.param_groups[-1]["params"]:
- param.grad = None
- param.data = param.data.to(self._dtype)
+ # release moe grad
+ if len(self.working_moe_params) > 0:
+ for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+ master_moe_param.grad = None
+ working_moe_param.data = (
+ master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
+ )
# release the grad
grad_partition_groups = []
@@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+ if hasattr(self, "master_moe_params"):
+ for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+ master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+ if hasattr(self, "moe_master_to_working_map"):
+ return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return self._param_store.master_to_working_param
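The optimizer now keeps detached fp32 master copies of the expert-parallel params plus a `moe_master_to_working_map` keyed by `id(master)`, which is exactly what the MoE checkpoint IO earlier in this diff consumes. A dependency-free sketch of that bookkeeping and of the post-step copy-back:

```python
import torch

working_moe_params = [torch.randn(4, 4, dtype=torch.bfloat16) for _ in range(2)]
master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]
moe_master_to_working_map = {
    id(master): working for master, working in zip(master_moe_params, working_moe_params)
}

# After self.optim.step() has updated the fp32 masters, the working params are refreshed:
for master, working in zip(master_moe_params, working_moe_params):
    working.data = master.data.to(working.device).to(working.dtype).detach()

assert all(w.dtype == torch.bfloat16 for w in working_moe_params)
```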
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 0c438c726..c25f19795 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -9,7 +9,7 @@
文档 |
例程 |
论坛 |
- 博客
+ 博客
[](https://github.com/hpcaitech/ColossalAI/stargazers)
[](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml)
diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md
index 18607a34c..f9c8fe475 100644
--- a/docs/source/en/get_started/installation.md
+++ b/docs/source/en/get_started/installation.md
@@ -23,7 +23,7 @@ pip install colossalai
If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime.
```shell
-CUDA_EXT=1 pip install colossalai
+BUILD_EXT=1 pip install colossalai
```
@@ -39,7 +39,7 @@ cd ColossalAI
pip install -r requirements/requirements.txt
# install colossalai
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
-If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`:
+If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't set `BUILD_EXT=1`:
@@ -61,7 +61,7 @@ unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
# install
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md
index fb6fd90ec..481efe98a 100644
--- a/docs/source/zh-Hans/features/1D_tensor_parallel.md
+++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md
@@ -19,10 +19,8 @@
当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为
$$
\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]
-```
-这就是所谓的行并行方式.
$$
-
+这就是所谓的行并行方式.
为了计算
$$
Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]
diff --git a/docs/source/zh-Hans/get_started/installation.md b/docs/source/zh-Hans/get_started/installation.md
index e75e42530..9e4f34707 100755
--- a/docs/source/zh-Hans/get_started/installation.md
+++ b/docs/source/zh-Hans/get_started/installation.md
@@ -20,10 +20,10 @@ pip install colossalai
**注:现在只支持Linux。**
-如果你想同时安装PyTorch扩展的话,可以添加`CUDA_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。
+如果你想同时安装PyTorch扩展的话,可以添加`BUILD_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。
```shell
-CUDA_EXT=1 pip install colossalai
+BUILD_EXT=1 pip install colossalai
```
## 从源安装
@@ -38,10 +38,10 @@ cd ColossalAI
pip install -r requirements/requirements.txt
# install colossalai
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
-如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`:
+如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`BUILD_EXT=1`:
```shell
pip install .
@@ -60,7 +60,7 @@ unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
# install
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py
deleted file mode 100644
index 2b2356b18..000000000
--- a/examples/language/llama2/attn.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from types import MethodType
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
-
-SUPPORT_XFORMERS = False
-SUPPORT_FLASH2 = False
-try:
- import xformers.ops as xops
-
- SUPPORT_XFORMERS = True
-except ImportError:
- pass
-
-try:
- from flash_attn import flash_attn_func
-
- SUPPORT_FLASH2 = True
-except ImportError:
- pass
-
-SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2
-
-
-def llama_flash_attention(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- # q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- if SUPPORT_FLASH2:
- attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
- else:
- attn_output = xops.memory_efficient_attention(
- query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
- )
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def replace_xformers(model: nn.Module):
- for module in model.modules():
- if isinstance(module, LlamaAttention):
- module.forward = MethodType(llama_flash_attention, module)
diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py
new file mode 120000
index 000000000..4e95c7bfa
--- /dev/null
+++ b/examples/language/llama2/attn.py
@@ -0,0 +1 @@
+../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
\ No newline at end of file
diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py
index b8f70ce9c..54b023f64 100644
--- a/examples/language/llama2/benchmark.py
+++ b/examples/language/llama2/benchmark.py
@@ -3,7 +3,7 @@ import resource
from contextlib import nullcontext
import torch
-from attn import SUPPORT_FLASH, replace_xformers
+from attn import replace_with_flash_attention
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
@@ -188,8 +188,7 @@ def main():
model.gradient_checkpointing_enable()
if args.xformers:
- assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
- replace_xformers(model)
+ replace_with_flash_attention(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
index 66b540076..3dbd0cf35 100644
--- a/examples/language/llama2/finetune.py
+++ b/examples/language/llama2/finetune.py
@@ -9,7 +9,7 @@ from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
@@ -219,8 +219,7 @@ def main():
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.flash_attention:
- assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
- replace_xformers(model)
+ replace_with_flash_attention(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index 4cdf93e19..fe7d95830 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -8,7 +8,7 @@ from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
@@ -238,8 +238,7 @@ def main():
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.flash_attention:
- assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
- replace_xformers(model)
+ replace_with_flash_attention(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py
index b4c40c9f1..3adb65fb8 100644
--- a/extensions/cpp_extension.py
+++ b/extensions/cpp_extension.py
@@ -126,7 +126,7 @@ class _CppExtension(_Extension):
def load(self):
try:
op_kernel = self.import_op()
- except ImportError:
+ except (ImportError, ModuleNotFoundError):
# if import error occurs, it means that the kernel is not pre-built
# so we build it jit
op_kernel = self.build_jit()
diff --git a/setup.py b/setup.py
index 1244bfff0..e54ec41ea 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,5 @@
import os
import sys
-from datetime import datetime
from typing import List
from setuptools import find_packages, setup
@@ -15,7 +14,6 @@ except ImportError:
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1
-IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1
# we do not support windows currently
if sys.platform == "win32":
@@ -96,23 +94,15 @@ if BUILD_EXT:
else:
ext_modules = []
-# always put not nightly branch as the if branch
-# otherwise github will treat colossalai-nightly as the project name
-# and it will mess up with the dependency graph insights
-if not IS_NIGHTLY:
- version = get_version()
- package_name = "colossalai"
-else:
- # use date as the nightly version
- version = datetime.today().strftime("%Y.%m.%d")
- package_name = "colossalai-nightly"
+version = get_version()
+package_name = "colossalai"
setup(
name=package_name,
version=version,
packages=find_packages(
exclude=(
- "op_builder",
+ "extensions",
"benchmark",
"docker",
"tests",
@@ -121,8 +111,9 @@ setup(
"tests",
"scripts",
"requirements",
+ "extensions",
"*.egg-info",
- )
+ ),
),
description="An integrated large-scale model training system with efficient parallelization techniques",
long_description=fetch_readme(),
@@ -153,10 +144,7 @@ setup(
],
package_data={
"colossalai": [
- "_C/*.pyi",
- "kernel/cuda_native/csrc/*",
- "kernel/cuda_native/csrc/kernel/*",
- "kernel/cuda_native/csrc/kernels/include/*",
+ "kernel/extensions/csrc/**/*",
]
},
)
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 67b0bef50..d629e769d 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize(
"test_args",
[
+ {
+ "batch_size": 8,
+ "num_steps": 4,
+ "tp": 2,
+ "pp": 2,
+ "pp_style": "1f1b",
+ "num_model_chunks": 1,
+ "num_microbatches": 4,
+ "zero": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ "max_length": 512,
+ "gradient_accumulation_step": 2,
+ },
{
"batch_size": 8,
"num_steps": 4,
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 708a1906b..61cac1d83 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -97,7 +97,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
new_model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
- new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+ new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
data = data_gen_fn()
@@ -109,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
booster.backward(loss, optimizer)
optimizer.step()
+ for group in optimizer.param_groups:
+ group["lr"] = 0.1
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
@@ -127,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
check_state_dict_equal(
optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
)
+ for group in new_optimizer.param_groups:
+ assert group["lr"] == 0.1
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index a42b550cd..b5cb31715 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -83,7 +83,8 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
optimizer.backward(loss)
optimizer.step()
-
+ for group in optimizer.param_groups:
+ group["lr"] = 0.1
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index dd41f8185..dca562a3b 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -10,6 +10,7 @@ from colossalai.booster import Booster
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt():
outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs)
+ with shared_tempdir() as tempdir:
+ model_ckpt_path = f"{tempdir}/model"
+ optim_ckpt_path = f"{tempdir}/optimizer"
+
+ run_model()
+
+ booster.save_model(fsdp_model, model_ckpt_path, shard=True)
+ booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
+
+ full_msd = fsdp_model.unwrap().state_dict()
+ full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+ import copy
+ sharded_osd = copy.deepcopy(full_osd)
+
+ run_model()
+
+ full_msd_updated = fsdp_model.unwrap().state_dict()
+ full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+ # these comparisons cost too much time and lead to a timeout, so they are skipped
+ # assert not compare_nested_dict(full_osd_updated, sharded_osd)
+ # assert not compare_nested_dict(full_msd_updated, full_msd)
+ outputs_first = fsdp_model(inputs)
+ assert criterion(outputs_first) != criterion(outputs)
+
+ booster.load_model(fsdp_model, model_ckpt_path)
+ booster.load_optimizer(optimizer, optim_ckpt_path)
+
+ full_msd_restore = fsdp_model.unwrap().state_dict()
+ sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+ assert compare_nested_dict(sharded_osd, sharded_osd_restore)
+ assert compare_nested_dict(full_msd_restore, full_msd)
+ outputs_sec = fsdp_model(inputs)
+ assert criterion(outputs_sec) == criterion(outputs)
+
def run_dist(rank, world_size, port):
# init dist env
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 721a4796a..17b790e3e 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,13 +1,22 @@
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.testing import assert_close
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
+
+
+def delete_moe_info(model):
+ for _, param in model.named_parameters():
+ if hasattr(param, "moe_info"):
+ delattr(param, "moe_info")
class MoeModel(nn.Module):
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
- assert not torch.allclose(a, b), \
- (f"expected tensors on rank {i} and {i + 1} not to be equal "
- f"but they are, {a} vs {b}")
+ assert not torch.allclose(a, b), (
+ f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+ )
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ y = loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+ """Sync the parameters of tp model from ep model
+
+ Args:
+ local_model (MoeModule)
+ ep_model (MoeModule)
+ """
+ for (local_name, local_param), (ep_name, ep_param) in zip(
+ local_model.named_parameters(), ep_model.named_parameters()
+ ):
+ assert local_name in ep_name, f"{local_name} != {ep_name}"
+ if "experts" not in local_name:
+ if assert_grad_flag:
+ assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+ assert torch.allclose(local_param.grad, ep_param.grad)
+ else:
+ local_param.data.copy_(ep_param.data)
+ continue
+
+ # gather param from ep model
+ param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+ all_param = torch.cat(param_list, dim=0)
+ if assert_grad_flag:
+ grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+ all_grad = torch.cat(grad_list, dim=0)
+
+ if assert_grad_flag:
+ assert torch.allclose(local_param, all_param)
+ assert torch.allclose(local_param.grad, all_grad)
+ else:
+ local_param.data.copy_(all_param.data)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+ rtol = None
+ atol = None
+ if dtype is torch.float16:
+ rtol = 5e-2
+ atol = 5e-4
+ elif dtype is torch.bfloat16:
+ rtol = 4e-3
+ atol = 4e-3
+
+ a = a.detach().to(dtype)
+ b = b.detach().to(dtype).to(a.device)
+
+ assert_close(a, b, rtol=rtol, atol=atol)
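A short usage note for the `loose_close` helper above: it casts both tensors to the given dtype and compares them with tolerances chosen for that precision (values below are illustrative):

```python
import torch

from tests.test_moe.moe_utils import loose_close

ref = torch.full((8,), 0.5)
approx = ref + 1e-3  # well within the 4e-3 tolerances selected for bf16
loose_close(approx, ref, dtype=torch.bfloat16)
```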
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 8f51e1663..d6dad2d7f 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -12,7 +12,6 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
sys.path.append(
@@ -95,6 +94,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
+ ep_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@@ -103,6 +103,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
+ ep_size=dist.get_world_size(),
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@@ -111,6 +112,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
+ ep_size=2,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
@@ -120,6 +122,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=2,
+ ep_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
@@ -130,27 +133,6 @@ def get_model(parallel):
def _test_moe_checkpoint(rank, parallel):
- if parallel == None:
- MOE_MANAGER.setup(
- parallel=None,
- )
- elif parallel == "ep":
- MOE_MANAGER.setup(
- parallel="EP",
- )
- elif parallel == "ep_zero":
- MOE_MANAGER.setup(
- parallel="EP",
- max_ep_size=2,
- )
- elif parallel == "hybrid":
- MOE_MANAGER.setup(
- parallel="EP",
- mode="fixed",
- fixed_dp_size=1,
- fixed_ep_size=2,
- fixed_pp_size=2,
- )
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
_test_moe_checkpoint(rank, parallel)
+@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py
index 7ba7fa6f6..9f6167692 100644
--- a/tests/test_moe/test_moe_router.py
+++ b/tests/test_moe/test_moe_router.py
@@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
-@pytest.mark.parametrize(["router", "num_groups"], [
- (Top1Router(), 1),
- (Top2Router(), 1),
- # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
- (4, 5, 8),
- (3, 4, 4),
-])
+@pytest.mark.parametrize(
+ ["router", "num_groups"],
+ [
+ (Top1Router(), 1),
+ (Top2Router(), 1),
+ # (TopKRouter(num_selected_experts=3), 4),
+ ],
+)
+@pytest.mark.parametrize(
+ ["batch_size", "seq_len", "num_experts"],
+ [
+ (4, 5, 8),
+ (3, 4, 4),
+ ],
+)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
router.train()
if isinstance(router, TopKRouter):
- _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+ combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
- _, combine_array, dispatch_mask = router(x)
+ combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
- _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+ combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
- _, combine_array, dispatch_mask = router(x)
+ combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
index f0795a4c7..1bff21066 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -4,102 +4,75 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
-def split_ddp_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
- model.train()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
-
- if isinstance(model, LowLevelZeroModel):
- optimizer.backward(loss)
- else:
- loss.backward()
- return y
-
-
-def run_zero_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
- zero_model = MoeModel()
- optimizer = torch.optim.Adam(zero_model.parameters())
- plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
- booster = Booster(plugin=plugin)
- zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel="EP")
+ moe_model = MoeModel().bfloat16()
+ moe_optimizer = torch.optim.Adam(moe_model.parameters())
+ moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+ moe_booster = Booster(plugin=moe_plugin)
+ moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
- torch_model = MoeModel()
- for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
- torch_param.data.copy_(zero_param.data)
- torch_model = torch_model.cuda()
- grad_handler = MoeGradientHandler(torch_model)
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel=None)
+ zero_model = MoeModel().bfloat16()
+ delete_moe_info(zero_model)
+ zero_optimizer = torch.optim.Adam(zero_model.parameters())
+ zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+ zero_booster = Booster(plugin=zero_plugin)
+ zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+ sync_local_from_ep(zero_model, moe_model)
- # assert zero model
- for (torch_name, torch_param), (zero_name, zero_param) in zip(
- torch_model.named_parameters(), zero_model.module.named_parameters()
- ):
- assert zero_name == torch_name
- assert torch.allclose(zero_param.data, torch_param.data)
-
- data = torch.randn(16, 4).cuda()
+ data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()
- torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
- zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
- assert torch.allclose(torch_out, zero_out)
- grad_handler.handle_gradient()
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+ assert torch.allclose(zero_out, moe_out)
- for (zero_name, zero_param), (torch_name, torch_param) in zip(
- zero_model.module.named_parameters(), torch_model.named_parameters()
+ for (moe_name, moe_param), (zero_name, zero_param) in zip(
+ moe_model.module.named_parameters(), zero_model.module.named_parameters()
):
- assert zero_name == torch_name
- zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
- if hasattr(zero_param, "moe_info"):
- assert len(zero_grad_list) == 0
- assert torch.allclose(zero_param.grad, torch_param.grad)
+ assert moe_name == zero_name
+ moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
+ zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
+ if hasattr(moe_param, "moe_info"):
+ assert len(moe_grad_list) == 0
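+            # reconstruct this rank's shard of the reference grad (stage 1 stores all ranks' shards, stage 2 only the local one)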
+ if stage == 1:
+ zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
+ else:
+ zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
+ assert torch.allclose(
+ moe_param.grad, zero_grad, atol=1e-5
+ ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
- assert len(zero_grad_list) > 0
- torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
- if stage == 2:
- torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
- assert len(zero_grad_list) == len(torch_grad_list)
- for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
- assert torch.allclose(zero_grad, torch_grad)
+ assert len(moe_grad_list) > 0
+ assert len(moe_grad_list) == len(zero_grad_list)
+ for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
+ assert torch.allclose(moe_grad, zero_grad)
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_MANAGER.setup(parallel="EP")
seed_all(42 + rank)
- run_zero_test(rank, world_size, stage=1)
- run_zero_test(rank, world_size, stage=2)
+ run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
- spawn(run_dist, world_size)
+def test_moe_zero_model(world_size, stage):
+ spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
- test_moe_zero_model(world_size=2)
+ test_moe_zero_model(world_size=2, stage=1)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 0d2e2fb1b..4f6067aaa 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -4,89 +4,80 @@ import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
-def split_ddp_grad(grad, world_size):
- with torch.no_grad():
- grad = grad.clone().detach().flatten()
- padding_size = (world_size - grad.numel() % world_size) % world_size
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- splited_grad = grad.split(grad.numel() // world_size)
- return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
- model.train()
- with torch.cuda.amp.autocast(enabled=enable_autocast):
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
-
- if isinstance(model, LowLevelZeroModel):
- optimizer.backward(loss)
- else:
- loss.backward()
- return y
-
-
-def run_zero_optim_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
- zero_model = MoeModel()
- zero_optimizer = torch.optim.Adam(zero_model.parameters())
- plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
- booster = Booster(plugin=plugin)
- zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
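+    # EP model: experts sharded across ranks, wrapped with LowLevelZero in bf16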
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel="EP")
+ moe_model = MoeModel().bfloat16()
+ moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
+ moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+ moe_booster = Booster(plugin=moe_plugin)
+ moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
- torch_model = MoeModel()
- for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
- torch_param.data.copy_(zero_param.data)
- torch_optimizer = torch.optim.Adam(torch_model.parameters())
- torch_model = torch_model.cuda()
- grad_handler = MoeGradientHandler(torch_model)
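+    # local reference: same model without expert parallelism; sync its weights from the EP model before wrapping it with ZeRO as well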
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel=None)
+ zero_model = MoeModel().bfloat16()
+ delete_moe_info(zero_model)
+ sync_local_from_ep(zero_model, moe_model)
+ zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
+ zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+ zero_booster = Booster(plugin=zero_plugin)
+ zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
- for _ in range(2):
- data = torch.randn(16, 4).cuda() / (local_rank + 1)
- label = torch.randint(0, 4, (16,)).cuda()
- run_fwd_bwd(torch_model, data, label, criterion, None)
- run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
- grad_handler.handle_gradient()
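+    # before any update, non-expert parameters of the two models should already match (experts are skipped since their shapes differ under EP)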
+ for (moe_name, moe_param), (zero_name, zero_param) in zip(
+ moe_model.named_parameters(), zero_model.named_parameters()
+ ):
+ if ".experts." in moe_name:
+ continue
+ assert moe_name == zero_name
+ assert torch.allclose(
+ moe_param.data, zero_param.data
+ ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
- torch_optimizer.step()
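+    # one training step: run forward/backward on both models with the same data, check the outputs agree, then step both optimizers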
+ for _ in range(1):
+ data = torch.randn(2, 4).bfloat16().cuda()
+ label = torch.randint(0, 4, (2,)).cuda()
+
+ moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ assert torch.allclose(zero_out, moe_out)
+ moe_optimizer.step()
zero_optimizer.step()
- for (torch_name, torch_param), (zero_name, zero_param) in zip(
- torch_model.named_parameters(), zero_model.named_parameters()
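+    # after the step, compare updated params; expert weights of the reference are sliced to this rank's EP shard and checked with a loose, dtype-aware tolerance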
+ for (moe_name, moe_param), (zero_name, zero_param) in zip(
+ moe_model.named_parameters(), zero_model.named_parameters()
):
- assert torch.allclose(
- torch_param.data, zero_param.data
- ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+ assert moe_name == zero_name
+ if is_moe_tensor(moe_param):
+ param_size = moe_param.shape[0]
+ zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
+ loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
- torch_optimizer.zero_grad()
+ moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_MANAGER.setup(parallel="EP")
- run_zero_optim_test(rank, world_size, stage=1)
- run_zero_optim_test(rank, world_size, stage=2)
+ seed_all(42 + rank)
+ run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
-def test_moe_zero_optim(world_size):
- spawn(run_dist, world_size)
+def test_moe_zero_optim(world_size, stage):
+ spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
- test_moe_zero_optim(world_size=2)
+ test_moe_zero_optim(world_size=2, stage=1)
diff --git a/tests/test_optimizer/test_lr_scheduler.py b/tests/test_optimizer/test_lr_scheduler.py
new file mode 100644
index 000000000..e0b084140
--- /dev/null
+++ b/tests/test_optimizer/test_lr_scheduler.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+from torch.optim import Adam
+
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+
+
+def test_lr_scheduler_save_load():
+ model = nn.Linear(10, 10)
+ optimizer = Adam(model.parameters(), lr=1e-3)
+ scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
+ new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2)
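+    # step the first scheduler, then load its state into the fresh one; the two state dicts must match exactly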
+ for _ in range(5):
+ scheduler.step()
+ state_dict = scheduler.state_dict()
+ new_scheduler.load_state_dict(state_dict)
+ assert state_dict == new_scheduler.state_dict()
+
+
+if __name__ == "__main__":
+ test_lr_scheduler_save_load()
diff --git a/version.txt b/version.txt
index 42045acae..c2c0004f0 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.4
+0.3.5