"
+)
+```
+
+For the chatglm2 model, it should be:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+ file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` is supported as usual, and the registered policy is picked up automatically by the auto-policy lookup, as the sketch below shows.
+
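+A minimal sketch of this workflow (assuming the `ShardConfig`/`ShardFormer` usage shown earlier in this README; the chatglm2 checkpoint name is only illustrative):
+```python
+import colossalai
+from transformers import AutoModel
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+# launched via torchrun, e.g. torchrun --standalone --nproc_per_node=2 run.py
+colossalai.launch_from_torch()
+
+# trust_remote_code=True loads the remote implementation, whose class resolves to
+# "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration"
+model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+
+# no policy is passed explicitly: the PolicyLocation registered above is found by the auto policy
+shard_former = ShardFormer(shard_config=ShardConfig())
+sharded_model, shared_params = shard_former.optimize(model)
+```
+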
### Write Your Unit Testing
This section serves as a guideline for testing the `shardformer` module.
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 11.2ms | 17.2ms |
-| 512 | 9.8ms | 19.5ms |
-| 1024 | 19.6ms | 18.9ms |
-| 2048 | 46.6ms | 30.8ms |
-| 4096 | 160.5ms | 90.4ms |
+| N_CTX | org_model | shard_model |
+|:-----:|:---------:|:-----------:|
+| 256 | 11.2ms | 17.2ms |
+| 512 | 9.8ms | 19.5ms |
+| 1024 | 19.6ms | 18.9ms |
+| 2048 | 46.6ms | 30.8ms |
+| 4096 | 160.5ms | 90.4ms |
@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
In the case of using 4 GPUs, the training times are as follows.
-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 10.0ms | 21.1ms |
-| 512 | 11.5ms | 20.2ms |
-| 1024 | 22.1ms | 20.6ms |
-| 2048 | 46.9ms | 24.8ms |
-| 4096 | 160.4ms | 68.0ms |
+| N_CTX | org_model | shard_model |
+|:-----:|:---------:|:-----------:|
+| 256 | 10.0ms | 21.1ms |
+| 512 | 11.5ms | 20.2ms |
+| 1024 | 22.1ms | 20.6ms |
+| 2048 | 46.9ms | 24.8ms |
+| 4096 | 160.4ms | 68.0ms |
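+
+These timings can be reproduced with `torchrun --standalone --nproc_per_node=2 performance_benchmark.py` (set `--nproc_per_node=4` for the 4-GPU run).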
@@ -475,10 +512,10 @@ warmup_fraction = 0.03
| accuracy | f1 | loss | GPU number | model sharded |
-| :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.82971 | 0.87713 | 0.23194 | 4 | True |
-| 0.83797 | 0.88006 | 0.22683 | 2 | True |
-| 0.84521 | 0.88700 | 0.21822 | 1 | False |
+|:--------:|:-------:|:-------:|:----------:|:-------------:|
+| 0.82971 | 0.87713 | 0.23194 | 4 | True |
+| 0.83797 | 0.88006 | 0.22683 | 2 | True |
+| 0.84521 | 0.88700 | 0.21822 | 1 | False |
Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
index b03e6201d..4caf61eb4 100644
--- a/colossalai/shardformer/examples/convergence_benchmark.py
+++ b/colossalai/shardformer/examples/convergence_benchmark.py
@@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any:
def train(args):
- colossalai.launch_from_torch(config={}, seed=42)
+ colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# prepare for data and dataset
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
index 81215dcdf..cce8b6f3a 100644
--- a/colossalai/shardformer/examples/performance_benchmark.py
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -1,6 +1,7 @@
"""
Shardformer Benchmark
"""
+
import torch
import torch.distributed as dist
import transformers
@@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 7b8aa5380..f17fad1b6 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
-from .embedding import Embedding1D, VocabParallelEmbedding1D
-from .linear import Linear1D_Col, Linear1D_Row
+from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
+from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -25,6 +25,9 @@ __all__ = [
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
+ "PaddingEmbedding",
+ "PaddingLMHead",
+ "VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index f3f6e59d3..abc865a34 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -8,7 +8,6 @@ from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
- FlashAttentionWithPaddingMaskLoader,
KernelLoader,
)
@@ -65,15 +64,17 @@ class ColoAttention:
half_dispatch_map = {
None: FlashAttentionLoader(),
AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
- AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+ AttnMaskType.PADDED: FlashAttentionLoader(),
AttnMaskType.CAUSAL: FlashAttentionLoader(),
- AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+ AttnMaskType.PADDED_CAUSAL: FlashAttentionLoader(),
}
# fp32
float_dispatch_map = {
None: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+ AttnMaskType.PADDED: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+ AttnMaskType.PADDED_CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
}
ColoAttention._kernel_dispatch_map = {
torch.float16: half_dispatch_map,
@@ -140,16 +141,22 @@ class ColoAttention:
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
else:
+ assert q_padding_mask.shape == (
+ b,
+ s_q,
+ ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})"
+ max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
if kv_padding_mask is None:
# self attention
kv_padding_mask = q_padding_mask
- assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+ max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices
+ else:
+ max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+ assert kv_padding_mask.shape == (
b,
s_kv,
- ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
- attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
- max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
- max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+ ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+ attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index d081b2040..cb7eceae4 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import gather_forward_split_backward, reduce_forward
-from .parallel_module import ParallelModule
+from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
-__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
+__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
return output_parallel
-class VocabParallelEmbedding1D(ParallelModule):
+class PaddingEmbedding(PaddingParallelModule):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: int = None,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ weight: Optional[nn.Parameter] = None,
+ make_vocab_size_divisible_by: int = 64,
+ *args,
+ **kwargs,
+ ):
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.embed_args = args
+ self.embed_kwargs = kwargs
+ self.padding_idx = padding_idx
+ if num_embeddings % make_vocab_size_divisible_by != 0:
+ self.num_embeddings = (
+ num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
+ )
+ # create weight and bias
+ if weight is None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+
+ super().__init__(self.num_embeddings, num_embeddings, weight)
+
+ if weight is None:
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init.normal_(self.weight)
+ self._fill_padding_idx_with_zero()
+
+ def _fill_padding_idx_with_zero(self) -> None:
+ if self.padding_idx is not None:
+ with torch.no_grad():
+ self.weight[self.padding_idx].fill_(0)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> PaddingParallelModule:
+ r"""
+ Convert a native pytorch embedding module to a parallel module.
+ """
+ LazyInitContext.materialize(module)
+ # get the origin attributes
+ num_embeddings = module.num_embeddings
+ embedding_dim = module.embedding_dim
+ padding_idx = module.padding_idx
+ device = module.weight.device
+ # create the parallel module
+ padding_embedding = PaddingEmbedding(
+ num_embeddings=num_embeddings,
+ embedding_dim=embedding_dim,
+ padding_idx=padding_idx,
+ device=device,
+ weight=module.weight,
+ *args,
+ **kwargs,
+ )
+
+ return padding_embedding
+
+
+class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
+ make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
- super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
@@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
- self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
- self.num_embeddings = self.num_embeddings_per_partition
+ # generate weight and bias
+ if weight is None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+
+ # calculate new padding size
+ multiple = make_vocab_size_divisible_by * tensor_parallel_size
+ if num_embeddings % multiple != 0:
+ self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
+
+ # resize vocabulary size
+ super().__init__(self.num_embeddings, num_embeddings, weight)
+
+ # deal with tensor parallelism
+ self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
@@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
- # parameter
- if weight is None:
- factory_kwargs = {"device": device, "dtype": dtype}
- self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
- else:
- weight.data = weight.data.to(device=device, dtype=dtype)
- self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
- ) -> ParallelModule:
+ ) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
@@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
-
output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)
-
# Mask the output embedding.
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 7c8619ad8..37c754241 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -32,7 +32,7 @@ from ._operation import (
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
-from .parallel_module import ParallelModule
+from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"]
@@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+ **kwargs,
):
- super().__init__()
+ super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
@@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
+
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
- module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
- *args,
**kwargs,
)
@@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
@staticmethod
def from_native_module(
- module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
- *args,
**kwargs,
)
@@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias
+
+
+class PaddingLMHead(PaddingParallelModule):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
+ make_vocab_size_divisible_by: int = 64,
+ weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+ bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+ ):
+ # Keep input parameters
+ self.in_features = in_features
+ self.out_features = out_features
+
+ if out_features % make_vocab_size_divisible_by != 0:
+ self.out_features = (
+ out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
+ )
+ if weight is None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+
+ if bias:
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ else:
+ bias_ = None
+
+ # resize embeddings
+ super().__init__(self.out_features, out_features, weight, bias_)
+
+ if weight is None:
+ self.reset_parameters(weight_initializer, bias_initializer)
+
+ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+ fan_in, fan_out = self.in_features, self.out_features
+ weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+ if self.bias is not None:
+ bias_initializer(self.bias, fan_in=fan_in)
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
+ ) -> PaddingParallelModule:
+ r"""
+ Convert a native PyTorch linear layer to a parallelized linear layer.
+ """
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+ out_features = module.out_features
+ bias = module.bias is not None
+ device = module.weight.device
+ # ensure only one process group is passed
+
+ lm_head_linear = PaddingLMHead(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ weight=module.weight,
+ bias_=module.bias,
+ **kwargs,
+ )
+
+ return lm_head_linear
+
+ def forward(self, input: Tensor) -> Tensor:
+ output = F.linear(input, self.weight, self.bias)
+ output = output[..., : self.old_num_embeddings]
+ return output
+
+
+class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
+ r"""Linear layer with column parallelism.
+
+ The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
+ its second dimension as :math:`A = [A_1, ..., A_p]`.
+
+ Args:
+ in_features (int): size of each input sample.
+ out_features (int): size of each output sample.
+ bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+ dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+ device (`torch.device`): The device of parameters, defaults to None.
+ process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+ gather_output (bool, optional): If true, call all-gather on output and make Y available
+ to all GPUs, otherwise, every GPU will have its output
+ which is :math:`Y_i = XA_i`, defaults to False
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+ overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
+ skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+ which is preserved for kernel fusion, defaults to False
+ weight_initializer (`typing.Callable`):
+ The initializer of weight, defaults to kaiming uniform initializer.
+ bias_initializer (`typing.Callable`):
+ The initializer of bias, defaults to xavier uniform initializer.
+
+ More details about ``initializer`` please refer to
+ `init `_.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ process_group: ProcessGroup = None,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
+ make_vocab_size_divisible_by: int = 64,
+ **kwargs,
+ ):
+ # create weight and bias
+ if weight is None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
+ if bias:
+ if bias_ is None:
+ bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
+ else:
+ bias_ = None
+
+ # calculate new vocab size
+ self.tensor_parallel_size = dist.get_world_size(group=process_group)
+ new_out_features = out_features
+ multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
+ if out_features % multiple != 0:
+ new_out_features = out_features + multiple - (out_features % multiple)
+
+ super().__init__(
+ in_features=in_features,
+ out_features=new_out_features,
+ bias=bias,
+ device=device,
+ process_group=process_group,
+ weight=weight,
+ bias_=bias_,
+ **kwargs,
+ new_num_embeddings=new_out_features,
+ old_num_embeddings=out_features,
+ )
+ # get the length of valid embeddings
+ tp_rank = dist.get_rank(process_group)
+ partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
+ if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
+ self.num_valid_embeddings_local = partition_size
+ elif self.old_num_embeddings >= tp_rank * partition_size:
+ self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
+ else:
+ self.num_valid_embeddings_local = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
+ ) -> PaddingParallelModule:
+ r"""
+ Convert a native PyTorch linear layer to a parallelized linear layer.
+ """
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+ out_features = module.out_features
+ bias = module.bias is not None
+ device = module.weight.device
+
+ lm_head_linear = VocabParallelLMHead1D(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
+ **kwargs,
+ )
+
+ return lm_head_linear
+
+ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
+ # get forward output
+ if self.skip_bias_add:
+ output, bias = super().forward(input_)
+ else:
+ output = super().forward(input_)
+
+ # delete the padding of output
+ if self.gather_output:
+ output = output[..., : self.old_num_embeddings]
+ else:
+ output = output[..., : self.num_valid_embeddings_local]
+
+ # return
+ if self.skip_bias_add:
+ return output, bias
+ return output
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index c4cf3fb85..6d99efc19 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
"""
@staticmethod
- def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
+ def forward(
+ ctx,
+ vocab_logits: torch.Tensor,
+ target: torch.Tensor,
+ ignore_index: int,
+ process_group: ProcessGroup,
+ vocab_size: int,
+ ):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
@@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
- partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
- global_vocab_size = partition_vocab_size * world_size
+ if vocab_size is None:
+ partition_vocab_size = vocab_logits.size()[-1]
+ global_vocab_size = partition_vocab_size * world_size
+ else:
+ global_vocab_size = vocab_size
+ partition_vocab_size = global_vocab_size // world_size
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
+ if up_threshold > global_vocab_size:
+ up_threshold = global_vocab_size
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
@@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
- logits_2d = vocab_logits.view(-1, partition_vocab_size)
+ self_vocab_size = vocab_logits.size()[-1]
+ logits_2d = vocab_logits.view(-1, self_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
@@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
- return grad_logits, None, None, None
+ return grad_logits, None, None, None, None
def cross_entropy_1d(
- vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
+ vocab_logits: torch.Tensor,
+ labels: torch.Tensor,
+ ignore_index: int = -100,
+ process_group: ProcessGroup = None,
+ vocab_size: int = None,
) -> torch.Tensor:
- return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
+ return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 43dd153af..5aa212600 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -225,7 +225,13 @@ class FusedLayerNorm(BaseLayerNorm):
# fall back to the normal fused layernorm if the fast layernorm kernel is not built
ApexFusedLayerNorm = FusedLayerNormWithHook
else:
- ApexFusedLayerNorm = FusedLayerNormWithHook
+ try:
+ ApexFusedLayerNorm = FusedLayerNormWithHook
+ except NameError:
+ warnings.warn(
+ "Please install Apex from source to use fused kernels, or set self.enable_fused_normalization = False. Using vanilla layernorm instead."
+ )
+ return module
layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
@@ -275,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
)
LazyInitContext.materialize(module)
- # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
- if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
- normalized_shape = module.weight.shape[0]
- eps = module.variance_epsilon
- elementwise_affine = True
- else:
- # get the attributes of the module
- normalized_shape = module.normalized_shape
- eps = module.eps
- elementwise_affine = module.elementwise_affine
+
+ # try to get normalized_shape, eps, elementwise_affine from the module
+ normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+ eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+ elementwise_affine = getattr(module, "elementwise_affine", True)
rmsnorm = FusedRMSNormWithHook(
- normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+ normalized_shape=normalized_shape,
+ eps=eps,
+ elementwise_affine=elementwise_affine,
)
rmsnorm.weight = module.weight
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index 6c0d83cc7..11ef73538 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -3,7 +3,7 @@
import itertools
from abc import ABC, abstractmethod
-from typing import List, Union
+from typing import List, Optional, Union
import torch
import torch.nn as nn
@@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
is_distributed_tensor,
sharded_tensor_to_param,
)
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["ParallelModule"]
class ParallelModule(nn.Module, ABC):
+ def __init__(self, **kwargs):
+ super().__init__()
+
@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
@@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
- destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
+ destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
@@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
+
+
+class PaddingParallelModule(ParallelModule):
+ def __init__(
+ self,
+ new_num_embeddings: int,
+ old_num_embeddings: int,
+ weight: Optional[nn.Parameter],
+ bias_: Optional[nn.Parameter] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.new_num_embeddings = new_num_embeddings
+ self.old_num_embeddings = old_num_embeddings
+ self.weight = weight
+ self.bias = bias_
+
+ if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
+ self.resize_embedding_weight()
+
+ if self.bias is not None and not (
+ is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
+ ):
+ self.resize_embedding_bias()
+
+ @abstractmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
+ ) -> "PaddingParallelModule":
+ """
+ Convert a native PyTorch module to a parallelized module.
+
+ Args:
+ module (nn.Module): the module to be converted.
+ process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
+ If this is a list, the process group at the ith index of the list will correspond to the process group
+ in the ith axis of the device mesh. Defaults to None, which means the global process group.
+ """
+ raise NotImplementedError
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ r"""Saves module state to `destination` dictionary, containing a state
+ of the module, but not its descendants. This is called on every
+ submodule in :meth:`~torch.nn.Module.state_dict`.
+
+ In rare cases, subclasses can achieve class-specific behavior by
+ overriding this method with custom logic.
+
+ Args:
+ destination (dict): a dict where state will be stored
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ """
+ for name, param in self._parameters.items():
+ if param is not None:
+ param = gather_distributed_param(param, keep_vars=keep_vars)
+ if is_padded_tensor(param):
+ param = to_unpadded_tensor(param)
+ destination[prefix + name] = param.data
+
+ for name, buf in self._buffers.items():
+ if buf is not None and name not in self._non_persistent_buffers_set:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+ destination[extra_state_key] = self.get_extra_state()
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
+ this module, but not its descendants. This is called on every submodule
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+ For state dicts without metadata, :attr:`local_metadata` is empty.
+ Subclasses can achieve class-specific backward compatible loading using
+ the version number at `local_metadata.get("version", None)`.
+
+ .. note::
+ :attr:`state_dict` is not the same object as the input
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+ it can be modified.
+
+ Args:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ local_metadata (dict): a dict containing the metadata for this module.
+ See
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` with :attr:`prefix` match the names of
+ parameters and buffers in this module
+ missing_keys (list of str): if ``strict=True``, add missing keys to
+ this list
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
+ keys to this list
+ error_msgs (list of str): error messages should be added to this
+ list, and will be reported together in
+ :meth:`~torch.nn.Module.load_state_dict`
+ """
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+ local_state = {k: v for k, v in local_name_params if v is not None}
+
+ for name, param in local_state.items():
+ key = prefix + name
+
+ if key in state_dict:
+ input_param = state_dict[key]
+ if not torch.overrides.is_tensor_like(input_param):
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "expected torch.Tensor or Tensor-like object from checkpoint but "
+ "received {}".format(key, type(input_param))
+ )
+ continue
+
+ if is_padded_tensor(param):
+ input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
+
+ if is_distributed_tensor(param):
+ # shard the input param
+ device_mesh = get_device_mesh(param)
+ sharding_spec = get_sharding_spec(param)
+ sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
+ input_param = sharded_tensor_to_param(sharded_tensor)
+ elif is_customized_distributed_tensor(param):
+ input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
+
+ # This is used to avoid copying uninitialized parameters into
+ # non-lazy modules, since they dont have the hook to do the checks
+ # in such case, it will error when accessing the .shape attribute.
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+ input_param = input_param[0]
+
+ if not is_param_lazy and input_param.shape != param.shape:
+ # local shape should match the one in checkpoint
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(key, input_param.shape, param.shape)
+ )
+ continue
+
+ try:
+ with torch.no_grad():
+ param.copy_(input_param)
+ except Exception as ex:
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+ )
+ elif strict:
+ missing_keys.append(key)
+
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+ if extra_state_key in state_dict:
+ self.set_extra_state(state_dict[extra_state_key])
+ elif strict:
+ missing_keys.append(extra_state_key)
+ elif strict and (extra_state_key in state_dict):
+ unexpected_keys.append(extra_state_key)
+
+ if strict:
+ for key in state_dict.keys():
+ if key.startswith(prefix) and key != extra_state_key:
+ input_name = key[len(prefix) :]
+ input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
+ if input_name not in self._modules and input_name not in local_state:
+ unexpected_keys.append(key)
+
+ def resize_embedding_weight(self):
+ self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
+
+ def resize_embedding_bias(self):
+ self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 0838fcee6..e7679f0ec 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1287,3 +1287,16 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
)
return forward
+
+
+def get_jit_fused_bert_intermediate_forward():
+ from transformers.models.bert.modeling_bert import BertIntermediate
+
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self: BertIntermediate, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, bias = self.dense(hidden_states)
+ hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index bd84c87c6..96e8a9d0c 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -129,3 +129,17 @@ def get_jit_fused_blip2_QFormer_output_forward():
return hidden_states
return forward
+
+
+def get_jit_fused_blip2_mlp_forward():
+ from transformers.models.blip_2.modeling_blip_2 import Blip2MLP
+
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self: Blip2MLP, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, bias = self.fc1(hidden_states)
+ hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index fe70376e1..c4f326364 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -6,6 +6,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ class BloomPipelineForwards:
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
# causal_mask is constructed every stage and its input is passed through different stages
- causal_mask = self._prepare_attn_mask(
+ causal_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape=(batch_size, seq_length),
+ inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)
-
+ causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config and shard_config.enable_sequence_parallelism:
@@ -227,21 +229,15 @@ class BloomPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
+ use_cache,
+ output_attentions,
)
else:
outputs = block(
@@ -1002,11 +998,13 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
- causal_mask = self._prepare_attn_mask(
+ causal_mask = _prepare_4d_causal_attention_mask(
attention_mask,
input_shape=(batch_size, seq_length),
+ inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)
+ causal_mask = causal_mask.bool()
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
+ use_cache,
+ output_attentions,
)
else:
outputs = block(
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 9207b34d0..53c151f02 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
device=query_layer.device,
)
temp_mask = (
- torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+ torch.ones(
+ query_layer.shape[2],
+ key_layer.shape[2],
+ dtype=torch.bool,
+ device=query_layer.device,
+ )
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
@@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
+ scale=1.0 / self.norm_factor,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_model_forward(
- self: ChatGLMModel,
+ self: "ChatGLMModel",
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
- hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
- hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+ hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_for_conditional_generation_forward(
- self: ChatGLMForConditionalGeneration,
+ self: "ChatGLMForConditionalGeneration",
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py
index 4e271dfe0..df3b09c71 100644
--- a/colossalai/shardformer/modeling/falcon.py
+++ b/colossalai/shardformer/modeling/falcon.py
@@ -1,9 +1,16 @@
+import math
+import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -99,11 +106,17 @@ def get_tp_falcon_decoder_layer_forward():
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
+ **kwargs,
):
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
residual = hidden_states
if self.config.new_decoder_architecture:
@@ -117,10 +130,12 @@ def get_tp_falcon_decoder_layer_forward():
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
+ position_ids=position_ids,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
+ **kwargs,
)
attention_output = attn_outputs[0]
@@ -154,87 +169,6 @@ def get_tp_falcon_decoder_layer_forward():
return forward
-def get_falcon_flash_attention_forward():
- try:
- from xformers.ops import memory_efficient_attention as me_attention
- except:
- raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
- from transformers.models.falcon.modeling_falcon import FalconAttention
-
- def forward(
- self: FalconAttention,
- hidden_states: torch.Tensor,
- alibi: Optional[torch.Tensor],
- attention_mask: torch.Tensor,
- layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- head_mask: Optional[torch.Tensor] = None,
- use_cache: bool = False,
- output_attentions: bool = False,
- ):
- fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
- # 3 x [batch_size, seq_length, num_heads, head_dim]
- (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-
- batch_size, query_length, _, _ = query_layer.shape
-
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
- key_layer = key_layer.transpose(1, 2).reshape(
- batch_size * num_kv_heads,
- query_length,
- self.head_dim,
- )
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
-
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
- if layer_past is not None:
- past_key, past_value = layer_past
- # concatenate along seq_length dimension:
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
- # - value: [batch_size * self.num_heads, kv_length, head_dim]
- key_layer = torch.cat((past_key, key_layer), dim=1)
- value_layer = torch.cat((past_value, value_layer), dim=1)
-
- _, kv_length, _ = key_layer.shape
- if use_cache:
- present = (key_layer, value_layer)
- else:
- present = None
-
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).contiguous()
-
- if alibi is not None:
- attention_mask_float = (
- attention_mask_float + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta
- )
-
- batch_size, src_len = query_layer_.size()[0], query_layer_.size()[1]
- tgt_len = key_layer_.size()[1]
- attention_mask_float = attention_mask_float.expand(batch_size, self.num_heads, src_len, tgt_len).contiguous()
- context_layer = me_attention(
- query_layer_,
- key_layer_,
- value_layer_,
- attn_bias=attention_mask_float,
- scale=self.inv_norm_factor,
- p=self.attention_dropout.p,
- )
- batch_size, seq_length, _, _ = context_layer.shape
- context_layer = context_layer.reshape(batch_size, seq_length, -1)
-
- output_tensor = self.dense(context_layer)
-
- return output_tensor, present
-
- return forward
-
-
class FalconPipelineForwards:
"""
This class serves as a micro library for falcon pipeline forwards.
@@ -246,6 +180,7 @@ class FalconPipelineForwards:
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -274,17 +209,6 @@ class FalconPipelineForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if past_key_values is None:
- past_key_values = tuple([None] * len(self.h))
- else:
- past_key_values = self._convert_to_rw_cache(past_key_values)
-
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape batch_size x num_heads x N x N
- # head_mask has shape n_layer x batch x num_heads x N x N
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@@ -295,16 +219,22 @@ class FalconPipelineForwards:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
-
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
-
hidden_states = inputs_embeds
-
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
@@ -312,22 +242,80 @@ class FalconPipelineForwards:
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
- else:
- attention_mask = attention_mask.to(hidden_states.device)
+ past_key_values_length = past_key_values[0][0].shape[-2]
if self.use_alibi:
- alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+ mask = (
+ torch.ones(
+ (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
+ )
+ if attention_mask is None
+ else attention_mask
+ )
+ alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
- causal_mask = self._prepare_attn_mask(
- attention_mask,
- input_shape=(batch_size, seq_length),
- past_key_values_length=past_key_values_length,
- )
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ if alibi is None:
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif head_mask is None:
+ alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+
+ attention_mask_2d = attention_mask
+ # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ # We take care to integrate alibi bias in the attention_mask here.
+ if attention_mask_2d is None:
+ attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
+ else:
+ attention_mask = torch.masked_fill(
+ alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+ attention_mask < -1,
+ torch.finfo(alibi.dtype).min,
+ )
+
+ # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+ # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+ if seq_length > 1:
+ attention_mask = AttentionMaskConverter._unmask_unattended(
+ attention_mask, attention_mask_2d, unmasked_value=0.0
+ )
+ else:
+ # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
@@ -337,31 +325,23 @@ class FalconPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
hidden_states,
alibi,
- causal_mask,
+ attention_mask,
+ position_ids,
head_mask[i],
+ layer_past,
+ use_cache,
+ output_attentions,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
- attention_mask=causal_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@@ -382,9 +362,6 @@ class FalconPipelineForwards:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- if presents is not None:
- presents = self._convert_cache_to_standard_format(presents, batch_size)
-
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 1306c8aa6..bfa995645 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
-from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__)
@@ -178,11 +177,9 @@ class GPT2PipelineForwards:
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage():
- if position_ids is not None:
- position_ids = position_ids.view(-1, input_shape[-1])
- else:
+ if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
- position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
@@ -240,22 +237,16 @@ class GPT2PipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache, output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
+ use_cache,
+ output_attentions,
)
else:
outputs = block(
@@ -397,13 +388,11 @@ class GPT2PipelineForwards:
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=self.lm_head.out_features,
)
else:
loss = loss_fct(shift_logits, shift_labels)
- if not shard_config.parallel_output:
- lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
-
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1301,12 +1290,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ shift_logits,
+ shift_labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=self.lm_head.out_features,
)
- if not shard_config.parallel_output:
- lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
-
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1321,3 +1310,18 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
)
return forward
+
+
+def get_jit_fused_gpt2_mlp_forward():
+ from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self: GPT2MLP, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states, bias = self.c_fc(hidden_states)
+ hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+ return forward
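
The new `get_jit_fused_gpt2_mlp_forward` assumes `c_fc` is configured with `skip_bias_add`, so the projection returns `(hidden_states, bias)` and the bias addition is fused into the GeLU kernel. As a rough reference for what `JitGeLUFunction.apply(hidden_states, bias)` is expected to compute (the tanh approximation is an assumption, based on Megatron-style fused bias-GeLU kernels):

```python
import torch
import torch.nn.functional as F


def unfused_bias_gelu(hidden_states: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Re-add the bias the linear layer skipped, then apply GeLU in one pass.
    return F.gelu(hidden_states + bias, approximate="tanh")


x = torch.randn(2, 4, 16)
bias = torch.randn(16)
print(unfused_bias_gelu(x, bias).shape)  # torch.Size([2, 4, 16])
```

The same idea backs `get_jit_fused_vit_intermediate_forward` and the `skip_bias_add` flag added to the BERT policy further down.
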
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 5c254d1e7..4f4cec8bc 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -148,11 +148,9 @@ class GPTJPipelineForwards:
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# position id to be assigned not just for the first stage for attn input
- if position_ids is not None:
- position_ids = position_ids.view(-1, seq_length)
- else:
+ if position_ids is None:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
- position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_ids = position_ids.unsqueeze(0)
if stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
@@ -201,21 +199,15 @@ class GPTJPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache, output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
+ use_cache,
+ output_attentions,
)
else:
outputs = block(
@@ -627,7 +619,9 @@ def get_gptj_flash_attention_forward():
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
- present = (key, value)
+ # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+ # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+ present = (key.to(hidden_states.dtype), value)
else:
present = None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 0f1b4ad0a..8a6a7cf17 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -16,6 +17,8 @@ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -31,13 +34,6 @@ from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
-try:
- from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
-
- LATEST_VERSION = True
-except ImportError:
- LATEST_VERSION = False
-
class LlamaPipelineForwards:
"""
@@ -75,13 +71,13 @@ class LlamaPipelineForwards:
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
+ batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
+ batch_size, seq_length, _ = inputs_embeds.shape[:2]
else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -111,11 +107,12 @@ class LlamaPipelineForwards:
if position_ids is None:
position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
)
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
+ position_ids = position_ids.unsqueeze(0)
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
@@ -123,20 +120,32 @@ class LlamaPipelineForwards:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
- mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+ mask_shape,
+ hidden_states.dtype,
+ hidden_states.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
)
else:
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
- )
- if LATEST_VERSION:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
)
else:
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
)
if self.gradient_checkpointing and self.training:
@@ -149,7 +158,7 @@ class LlamaPipelineForwards:
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
+ next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
@@ -159,8 +168,10 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
+ num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
- model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+ model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+ num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
@@ -168,30 +179,22 @@ class LlamaPipelineForwards:
if output_hidden_states:
all_hidden_states += (hidden_states,)
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
if idx - start_idx < num_ckpt_layers:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
- None,
+ past_key_values,
+ output_attentions,
+ use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
- past_key_value=past_key_value,
+ past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
@@ -199,7 +202,7 @@ class LlamaPipelineForwards:
hidden_states = layer_outputs[0]
if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -212,7 +215,16 @@ class LlamaPipelineForwards:
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ ]
+ if v is not None
+ )
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@@ -316,7 +328,10 @@ class LlamaPipelineForwards:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ shift_logits,
+ shift_labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=self.lm_head.out_features,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -455,23 +470,25 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
- llama_version = 2
try:
from transformers.models.llama.modeling_llama import repeat_kv
except:
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
- llama_version = 1
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
bsz, q_len, _ = hidden_states.size()
if sp_mode in ["split_gather", "ring"]:
@@ -495,21 +512,23 @@ def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- past_key_value = (key_states, value_states) if use_cache else None
-
- # repeat k/v heads if n_kv_heads < n_heads
- if llama_version == 2:
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
@@ -570,7 +589,10 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@@ -584,7 +606,11 @@ def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
- mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+ mask_shape,
+ hidden_states.dtype,
+ hidden_states.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
)
if self.gradient_checkpointing and self.training:
@@ -735,11 +761,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
-
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
- shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+ shift_logits,
+ shift_labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=self.lm_head.out_features,
)
if not return_dict:
@@ -913,7 +941,10 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@@ -929,10 +960,12 @@ def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
if attention_mask is None:
attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
)
- attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
)
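
The Llama changes above track the transformers v4.36 move from tuple-style key/value caches to the `Cache`/`DynamicCache` API: the prefix length now comes from `get_usable_length`, and concatenation happens inside `cache.update`. A standalone sketch with made-up shapes (batch=1, kv_heads=2, seq=3, head_dim=4):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 2, 3, 4)
value = torch.randn(1, 2, 3, 4)

# Before the first update the cache holds nothing for layer 0.
prefix_len = cache.get_usable_length(3, layer_idx=0)  # 0 on the first step

# update() appends the new keys/values for this layer and returns the
# concatenated tensors, replacing the old torch.cat([past, new], dim=2) code.
key_states, value_states = cache.update(key, value, layer_idx=0)
print(prefix_len, key_states.shape)  # 0 torch.Size([1, 2, 3, 4])
```
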
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 0da1a35a0..d5f00fc9f 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -1,70 +1,608 @@
-from typing import Optional, Tuple
+import warnings
+from typing import List, Optional, Tuple, Union
import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.shard import ShardConfig
+
+from ..layer import ColoAttention
+
+logger = logging.get_logger(__name__)
-def get_mistral_flash_attention_forward():
+class MistralForwards:
+ @staticmethod
+ def mistral_model_forward(
+ self: MistralModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for Mistral models at the moment.")
+ use_cache = False
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ past_key_values_length = 0
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+
+ if shard_config.enable_flash_attention:
+ # in this case, attention_mask is a dict rather than a tensor
+ mask_shape = (batch_size, 1, seq_length, seq_length)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape,
+ hidden_states.dtype,
+ hidden_states.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
+ )
+ else:
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ num_ckpt_layers = 0
+ if self.gradient_checkpointing and self.training:
+ num_ckpt_layers = end_idx - start_idx
+ # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+ if shard_config.gradient_checkpoint_config is not None:
+ num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+ stage=stage_manager.stage,
+ num_stages=stage_manager.num_stages,
+ num_layers=end_idx - start_idx,
+ model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+ num_model_chunks=stage_manager.num_model_chunks,
+ )
+ assert num_ckpt_layers <= end_idx - start_idx
+
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if idx - start_idx < num_ckpt_layers:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+ else:
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def mistral_for_causal_lm_forward(
+ self: MistralForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+ >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = MistralForwards.mistral_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ past_key_values = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get("hidden_states")
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def mistral_for_sequence_classification_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = MistralForwards.mistral_model_forward(
+ self.model,
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ batch_size = inputs_embeds.shape[0]
+ else:
+ batch_size = hidden_states.shape[0]
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
+ logits.device
+ )
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+ else:
+ hidden_states = transformer_outputs.get("hidden_states")
+ return {"hidden_states": hidden_states}
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+def get_mistral_model_forward_for_flash_attn(shard_config: ShardConfig):
+ logger = logging.get_logger(__name__)
+ assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
+
+ def forward(
+ self: MistralModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = 0
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if shard_config.enable_flash_attention:
+ # in this case, attention_mask is a dict rather than a tensor
+ mask_shape = (batch_size, 1, seq_length, seq_length)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape,
+ inputs_embeds.dtype,
+ inputs_embeds.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
+ )
+ else:
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
+
+
+def get_mistral_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
- from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
def forward(
self: MistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
bsz, q_len, _ = hidden_states.size()
- assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = (
- self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- )
- value_states = (
- self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- )
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ # repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
- query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
- key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
- value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
+ assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
- flash_attention_mask = None
- attn_mask_type = AttnMaskType.causal
- if attention_mask != None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
- flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- attn_mask_type = AttnMaskType.paddedcausal
-
- attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
- attn_output = attention(
- query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
- )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
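
The new `MistralForwards` follows the same stage contract as the other pipeline forwards in this diff: intermediate stages hand back `{"hidden_states": ...}`, and only the last stage builds the full `ModelOutput`. A toy harness (not the actual ColossalAI pipeline engine) showing how those return values chain together:

```python
def run_stages(stage_forwards, input_ids):
    """Illustration only: chain staged forwards by their hand-off convention."""
    hidden_states = None
    for i, forward in enumerate(stage_forwards):
        out = forward(
            input_ids=input_ids if i == 0 else None,
            hidden_states=hidden_states,
        )
        if isinstance(out, dict) and "hidden_states" in out:
            hidden_states = out["hidden_states"]  # feed the next stage
        else:
            return out  # last stage: e.g. CausalLMOutputWithPast
```
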
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index a26526430..81521c30b 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -42,7 +43,7 @@ def _get_attention_mask(
is_causal=True,
)
else:
- attention_mask = self.decoder._prepare_decoder_attention_mask(
+ attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
@@ -112,7 +113,7 @@ class OPTPipelineForwards:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
inputs_embeds.dtype
-
+ hidden_states = inputs_embeds
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for intermediate stages.")
@@ -125,12 +126,25 @@ class OPTPipelineForwards:
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
- if attention_mask is None:
- attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
- elif attention_mask.shape[1] != mask_seq_length:
- raise ValueError(
- f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
- f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+ if self.decoder._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ attention_mask = (
+ torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+ if attention_mask is None
+ else attention_mask
+ )
+ else:
+ # 4d mask is passed through the layers
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+ elif attention_mask.shape[1] != mask_seq_length:
+ raise ValueError(
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+ )
+ causal_attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, hidden_states, past_key_values_length
)
if stage_manager.is_first_stage():
@@ -205,20 +219,14 @@ class OPTPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if decoder.gradient_checkpointing and decoder.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
+ output_attentions,
+ use_cache,
)
else:
layer_outputs = decoder_layer(
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 9c5ce3fb6..b35bb6b94 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
-from torch.utils.checkpoint import checkpoint
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -118,16 +117,13 @@ class T5PipelineForwards:
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
- if attention_mask is None:
- attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
- if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
- encoder_seq_length = encoder_hidden_states.shape[1]
- encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -138,7 +134,7 @@ class T5PipelineForwards:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
@@ -162,15 +158,8 @@ class T5PipelineForwards:
torch.cuda.set_device(hidden_states.device)
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return tuple(module(*inputs, use_cache, output_attentions))
-
- return custom_forward
-
- layer_outputs = checkpoint(
- create_custom_forward(layer_module),
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.forward,
hidden_states,
extended_attention_mask,
position_bias,
@@ -180,6 +169,8 @@ class T5PipelineForwards:
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
+ use_cache,
+ output_attentions,
)
else:
layer_outputs = layer_module(
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index e9c256a13..b1a5c4143 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -14,6 +14,8 @@ def _encoder_forward(
end_idx: int,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
return_dict: bool = True,
stage_manager: PipelineStageManager = None,
) -> Union[tuple, BaseModelOutput]:
@@ -23,20 +25,14 @@ def _encoder_forward(
layer_head_mask = head_mask[i] if head_mask is not None else None
if encoder.gradient_checkpointing and encoder.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs, False)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(layer_module),
+ layer_outputs = encoder._gradient_checkpointing_func(
+ layer_module.__call__,
hidden_states,
layer_head_mask,
+ output_attentions,
)
else:
- layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if not stage_manager.is_last_stage():
@@ -114,6 +110,8 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
end_idx=stage_index[1],
hidden_states=hidden_states,
head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
)
@@ -374,3 +372,15 @@ def get_jit_fused_vit_output_forward():
return hidden_states
return forward
+
+
+def get_jit_fused_vit_intermediate_forward():
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states, bias = self.dense(hidden_states)
+ hidden_states = JitGeLUFunction.apply(hidden_states, bias)
+
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index 7ccc79276..6d7df963a 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -5,6 +5,10 @@ from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -35,6 +39,8 @@ def _get_attention_mask(
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
@@ -47,12 +53,20 @@ def _get_attention_mask(
is_causal=True,
)
else:
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- hidden_states,
- past_key_values_length,
- )
+ input_shape = (batch_size, seq_length)
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True & head_mask can not be supported when using SDPA.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask, input_shape, hidden_states, past_key_values_length
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, hidden_states, past_key_values_length
+ )
return attention_mask
@@ -539,18 +553,12 @@ class WhisperPipelineForwards:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs, output_attentions)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(encoder_layer),
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
+ output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -702,20 +710,16 @@ class WhisperPipelineForwards:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
+ attention_mask = _get_attention_mask(
+ self, shard_config, inputs_embeds, past_key_values_length, attention_mask
+ )
+
# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
- attention_mask = _get_attention_mask(
- self,
- shard_config,
- inputs_embeds,
- past_key_values_length,
- attention_mask,
- )
-
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -732,7 +736,6 @@ class WhisperPipelineForwards:
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
)
input_shape = hidden_states.size()[:-1]
-
attention_mask = _get_attention_mask(
self,
shard_config,
@@ -756,16 +759,8 @@ class WhisperPipelineForwards:
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, use_cache)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -773,6 +768,8 @@ class WhisperPipelineForwards:
head_mask[idx] if head_mask is not None else None,
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value
+ output_attentions,
+ use_cache,
)
else:
layer_outputs = decoder_layer(
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 0991ace2c..d2b582af5 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -151,10 +151,10 @@ _POLICY_LIST = {
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
- "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+ "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
- "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+ "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Falcon
@@ -202,6 +202,13 @@ def _fullname(obj):
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
+ # patch custom models which are not in transformers
+ # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
+ # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+ if module.startswith("transformers_modules"):
+ split_module = module.split(".")
+ if len(split_module) >= 2:
+ module = f"{split_module[0]}.{split_module[-1]}"
return module + "." + klass.__qualname__
@@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
if policy_location is None:
raise NotImplementedError(
- f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+ f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
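
The `_fullname` patch above collapses dynamically generated `transformers_modules` paths so that Hub and local ChatGLM checkpoints resolve to the same policy key. Running the example paths from the comments through the same logic:

```python
examples = [
    # remote (Hugging Face Hub) layout with repo owner and commit hash
    "transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm",
    # local directory layout
    "transformers_modules.chatglm.modeling_chatglm",
]
for module in examples:
    split_module = module.split(".")
    if module.startswith("transformers_modules") and len(split_module) >= 2:
        module = f"{split_module[0]}.{split_module[-1]}"
    print(module)  # both print: transformers_modules.modeling_chatglm
```
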
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index d67ab0a3c..282cf0464 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
+
suffix: str
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ class ModulePolicyDescription:
object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
+
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None
@@ -195,3 +197,12 @@ class Policy(ABC):
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []
+
+ def tie_weight_check(self):
+ input_embedding = self.model.get_input_embeddings()
+ output_embedding = self.model.get_output_embeddings()
+ return (
+ input_embedding is not None
+ and output_embedding is not None
+ and id(input_embedding.weight) == id(output_embedding.weight)
+ )
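`tie_weight_check()` only inspects whether the input and output embeddings share one weight tensor. A quick way to see what it returns, sketched here with a randomly initialized GPT-2 from `transformers` (the model choice is purely illustrative):

```python
from transformers import GPT2Config, GPT2LMHeadModel

# GPT-2 ties lm_head.weight to wte.weight by default (tie_word_embeddings=True),
# so the same check the policy performs evaluates to True.
model = GPT2LMHeadModel(GPT2Config())
input_embedding = model.get_input_embeddings()
output_embedding = model.get_output_embeddings()
print(
    input_embedding is not None
    and output_embedding is not None
    and id(input_embedding.weight) == id(output_embedding.weight)
)  # True
```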
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 0a61d8cff..0c04f7d38 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -12,6 +12,7 @@ from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
+ get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
)
@@ -37,22 +38,14 @@ class BertPolicy(Policy):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- # TODO:
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
+ self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
+ BertIntermediate,
BertLayer,
BertModel,
BertOutput,
@@ -62,6 +55,13 @@ class BertPolicy(Policy):
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@@ -79,6 +79,9 @@ class BertPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={
"attention.self.all_head_size": self.model.config.hidden_size
@@ -134,6 +137,7 @@ class BertPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"overlap": overlap,
+ "skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@@ -150,16 +154,20 @@ class BertPolicy(Policy):
policy[BertEmbeddings] = ModulePolicyDescription(
sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="word_embeddings",
- target_module=col_nn.VocabParallelEmbedding1D,
- ),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
]
)
+ if self.enable_bias_gelu_fused:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_bert_intermediate_forward(),
+ },
+ policy=policy,
+ target_key=BertIntermediate,
+ )
if sp_mode == "split_gather":
self.append_or_create_method_replacement(
@@ -168,6 +176,18 @@ class BertPolicy(Policy):
target_key=BertModel,
)
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="word_embeddings",
+ target_module=embedding_cls,
+ )
+ ],
+ policy=policy,
+ target_key=BertEmbeddings,
+ )
+
# optimization configuration
# Handle bert layer
self.append_or_create_submodule_replacement(
@@ -237,8 +257,21 @@ class BertPolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True},
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
+ ),
+ policy=base_policy,
+ target_key=BertLMPredictionHead,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="decoder",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=base_policy,
target_key=BertLMPredictionHead,
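The new `skip_bias_add` flag only makes sense together with the JIT-fused intermediate forward: the column-parallel linear leaves its bias unapplied so that a single fused kernel can do bias + GELU in one pass. A rough sketch of the arithmetic, using `torch.jit.script` as a stand-in for ColossalAI's actual fused kernel:

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def bias_gelu(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # One scripted op applying the deferred bias and the activation together.
    return F.gelu(x + bias)

x = torch.randn(4, 128)
weight = torch.randn(128, 512)
bias = torch.randn(512)

reference = F.gelu(x @ weight + bias)  # what BertIntermediate computes unfused
fused = bias_gelu(bias, x @ weight)    # skip_bias_add path: raw linear output + fused bias-gelu
print(torch.allclose(reference, fused, atol=1e-6))  # True
```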
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 9be2a1e78..32d4edadb 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -3,6 +3,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
+ get_jit_fused_blip2_mlp_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
@@ -17,22 +18,17 @@ class BlipPolicy(Policy):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- # TODO:
- vocab_size = self.model.config.qformer_config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
+ self.enable_bias_gelu_fused = (
+ self.shard_config.enable_jit_fused and self.model.config.vision_config.hidden_act == "gelu"
+ )
return self.model
def module_policy(self):
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Attention,
Blip2EncoderLayer,
+ Blip2MLP,
Blip2QFormerLayer,
Blip2QFormerModel,
Blip2QFormerOutput,
@@ -43,12 +39,22 @@ class BlipPolicy(Policy):
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.num_heads": self.model.config.vision_config.num_attention_heads
@@ -75,6 +81,7 @@ class BlipPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.Linear1D_Col,
+ kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
@@ -202,22 +209,56 @@ class BlipPolicy(Policy):
],
)
- policy[OPTForCausalLM] = ModulePolicyDescription(
- sub_module_replacement=[
+ policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+ if self.enable_bias_gelu_fused:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_blip2_mlp_forward(),
+ },
+ policy=policy,
+ target_key=Blip2MLP,
+ )
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
- target_module=col_nn.VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
- SubModuleReplacementDescription(
- suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True},
- ),
- ]
+ ],
+ policy=policy,
+ target_key=OPTForCausalLM,
)
- policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
-
+ if self.shard_config.enable_tensor_parallelism:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
+ ),
+ ],
+ policy=policy,
+ target_key=OPTForCausalLM,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ ],
+ policy=policy,
+ target_key=OPTForCausalLM,
+ )
# optimization configuration
# Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 2becadc3f..4f076d233 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -24,27 +24,12 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
class BloomPolicy(Policy):
def __init__(self) -> None:
super().__init__()
- import transformers
- from packaging.version import Version
-
- assert Version(transformers.__version__) <= Version(
- "4.33.0"
- ), "The Bloom model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@@ -52,6 +37,13 @@ class BloomPolicy(Policy):
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@@ -69,6 +61,9 @@ class BloomPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
@@ -112,12 +107,19 @@ class BloomPolicy(Policy):
method_replacement={
"build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
- sub_module_replacement=[
+ )
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
- target_module=col_nn.VocabParallelEmbedding1D,
- )
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
],
+ policy=policy,
+ target_key=BloomModel,
)
# optimization configuration
@@ -282,7 +284,21 @@ class BloomForCausalLMPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+ suffix="lm_head",
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs=dict(
+ gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+ ),
+ ),
+ policy=policy,
+ target_key=BloomForCausalLM,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=BloomForCausalLM,
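All of these policies now pass `make_vocab_size_divisible_by` into the padded embedding and LM-head layers instead of resizing token embeddings in `preprocess()`. The rounding this implies is just a ceiling to the next multiple; the real layers may additionally fold in the tensor-parallel size, so treat the sketch below (with 64 as an example divisor) as arithmetic illustration only:

```python
def pad_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> int:
    # Round the vocabulary up to the next multiple of the requested divisor.
    multiple = make_vocab_size_divisible_by
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(pad_vocab_size(50257))        # 50304: GPT-2's vocab rounded up to a multiple of 64
print(pad_vocab_size(32000, 128))   # 32000: already divisible, nothing is padded
```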
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index dabc14bff..4baf89f6a 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -7,7 +7,6 @@ from torch import Tensor
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_forward_fn,
@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+__all__ = [
+ "ChatGLMPolicy",
+ "ChatGLMModelPolicy",
+ "ChatGLMForConditionalGenerationPolicy",
+]
class ChatGLMPolicy(Policy):
@@ -25,27 +28,24 @@ class ChatGLMPolicy(Policy):
pass
def preprocess(self):
- # Resize embedding
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.padded_vocab_size
- world_size = self.shard_config.tensor_parallel_size
-
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
-
if self.pipeline_stage_manager is not None:
# the batch_size_dim is bounded to Model
bsz_dim = 1
setattr(self.model, "batch_size_dim", bsz_dim)
+ self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
-
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm
@@ -68,17 +68,27 @@ class ChatGLMPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
- policy[ChatGLMModel] = ModulePolicyDescription(
- attribute_replacement={},
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="embedding.word_embeddings",
- target_module=col_nn.VocabParallelEmbedding1D,
- )
- ],
- )
-
- policy[GLMBlock] = ModulePolicyDescription(
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+ attn_kwargs = {
+ "self_attention.qkv_hidden_size": (
+ self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+ )
+ // self.shard_config.tensor_parallel_size,
+ }
+ if self.model.config.multi_query_attention:
+ assert (
+ self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+ ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+ attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+ self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+ )
+ attn_kwargs["self_attention.qkv_hidden_size"] = (
+ self.model.config.kv_channels * self.model.config.num_attention_heads
+ + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+ ) // self.shard_config.tensor_parallel_size
+ policy["GLMBlock"] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
@@ -86,22 +96,23 @@ class ChatGLMPolicy(Policy):
self.model.config.kv_channels * self.model.config.num_attention_heads
)
// self.shard_config.tensor_parallel_size,
- "self_attention.qkv_hidden_size": (
- self.model.config.kv_channels * self.model.config.num_attention_heads * 3
- )
- // self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
* self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
+ **attn_kwargs,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
+ kwargs={
+ "seq_parallel_mode": sp_mode,
+ "seq_parallel_dim": 0,
+ "overlap": overlap,
+ },
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
@@ -114,6 +125,19 @@ class ChatGLMPolicy(Policy):
),
],
)
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="embedding.word_embeddings",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ ],
+ policy=policy,
+ target_key="ChatGLMModel",
+ )
# optimization configuration
self.append_or_create_submodule_replacement(
description=[
@@ -129,7 +153,7 @@ class ChatGLMPolicy(Policy):
),
],
policy=policy,
- target_key=GLMBlock,
+ target_key="GLMBlock",
)
if self.model.config.post_layer_norm:
@@ -141,7 +165,7 @@ class ChatGLMPolicy(Policy):
)
],
policy=policy,
- target_key=ChatGLMModel,
+ target_key="ChatGLMModel",
)
# use flash attention
@@ -151,7 +175,7 @@ class ChatGLMPolicy(Policy):
"forward": get_flash_core_attention_forward(),
},
policy=policy,
- target_key=CoreAttention,
+ target_key="CoreAttention",
)
# use sequence parallel
@@ -159,7 +183,7 @@ class ChatGLMPolicy(Policy):
self.append_or_create_method_replacement(
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
- target_key=ChatGLMModel,
+ target_key="ChatGLMModel",
)
# use jit fused operator
@@ -170,7 +194,7 @@ class ChatGLMPolicy(Policy):
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
- target_key=GLMBlock,
+ target_key="GLMBlock",
)
return policy
@@ -218,7 +242,10 @@ class ChatGLMPolicy(Policy):
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
- new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -232,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+ model_cls="ChatGLMModel",
+ new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+ policy=policy,
)
return policy
@@ -250,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
- model_cls=ChatGLMForConditionalGeneration,
+ model_cls="ChatGLMForConditionalGeneration",
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy,
)
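The multi-query branch above is where the attribute arithmetic is easiest to misread. A worked example with ChatGLM2-6B-style numbers (`kv_channels=128`, `num_attention_heads=32`, `multi_query_group_num=2`; check the model's `config.json` before relying on them) under a tensor-parallel size of 2:

```python
kv_channels = 128
num_attention_heads = 32
multi_query_group_num = 2
tensor_parallel_size = 2

qkv_hidden_size = (
    kv_channels * num_attention_heads            # full-width query projection
    + 2 * kv_channels * multi_query_group_num    # narrow shared key/value projections
) // tensor_parallel_size
num_groups_per_partition = multi_query_group_num // tensor_parallel_size

print(qkv_hidden_size)           # 2304 per rank (4608 in total)
print(num_groups_per_partition)  # 1 key/value group per rank
```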
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index fe61c406f..23d6efbeb 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -7,12 +7,7 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
-from ..modeling.falcon import (
- FalconPipelineForwards,
- build_falcon_alibi_tensor_fn,
- get_falcon_flash_attention_forward,
- get_tp_falcon_decoder_layer_forward,
-)
+from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["FalconPolicy"]
@@ -21,31 +16,16 @@ __all__ = ["FalconPolicy"]
class FalconPolicy(Policy):
def __init__(self) -> None:
super().__init__()
- import transformers
- from packaging.version import Version
-
- assert Version(transformers.__version__) <= Version(
- "4.33.0"
- ), "The Falcon model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
- from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
+ from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
warnings.warn(
@@ -58,7 +38,21 @@ class FalconPolicy(Policy):
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
policy = {}
+
+ embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
+ if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
attn_attribute_replacement = {
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -98,12 +92,19 @@ class FalconPolicy(Policy):
method_replacement={
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
- sub_module_replacement=[
+ )
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
- target_module=col_nn.VocabParallelEmbedding1D,
- )
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
],
+ policy=policy,
+ target_key=FalconModel,
)
# optimization configuration
@@ -141,11 +142,8 @@ class FalconPolicy(Policy):
)
if self.shard_config.enable_flash_attention:
- self.append_or_create_method_replacement(
- description={"forward": get_falcon_flash_attention_forward()},
- policy=policy,
- target_key=FalconAttention,
- )
+ warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
+
return policy
def postprocess(self):
@@ -232,11 +230,26 @@ class FalconForCausalLMPolicy(FalconPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+ suffix="lm_head",
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs=dict(
+ gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+ ),
),
policy=policy,
target_key=FalconForCausalLM,
)
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
+ ),
+ policy=policy,
+ target_key=FalconForCausalLM,
+ )
+
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForCausalLM,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 380a432dc..281ea88c2 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -10,6 +10,7 @@ from ..modeling.gpt2 import (
GPT2PipelineForwards,
get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
+ get_jit_fused_gpt2_mlp_forward,
get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn,
)
@@ -34,19 +35,31 @@ class GPT2Policy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
+ self.enable_bias_gelu_fused = (
+ self.shard_config.enable_jit_fused and self.model.config.activation_function == "gelu"
+ )
return self.model
def module_policy(self):
- from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+ from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+
+ ATTN_IMPLEMENTATION = {
+ "eager": GPT2Attention,
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@@ -71,12 +84,11 @@ class GPT2Policy(Policy):
self.shard_config.enable_flash_attention = False
use_flash_attention = False
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="wte",
- target_module=col_nn.VocabParallelEmbedding1D,
- ),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
@@ -114,6 +126,7 @@ class GPT2Policy(Policy):
"n_fused": 1,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
+ "skip_bias_add": self.enable_bias_gelu_fused,
},
),
SubModuleReplacementDescription(
@@ -137,6 +150,25 @@ class GPT2Policy(Policy):
),
],
)
+ if self.enable_bias_gelu_fused:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_gpt2_mlp_forward(),
+ },
+ policy=policy,
+ target_key=GPT2MLP,
+ )
+ if embedding_cls is not None:
+ # pad the vocabulary size so that it is divisible by shard_config.make_vocab_size_divisible_by
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="wte",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=GPT2Model,
+ )
# optimization configuration
self.append_or_create_submodule_replacement(
@@ -177,7 +209,7 @@ class GPT2Policy(Policy):
"forward": get_gpt2_flash_attention_forward(),
},
policy=policy,
- target_key=GPT2Attention,
+ target_key=attn_cls,
)
if not self.shard_config.pipeline_stage_manager:
policy[GPT2Model].method_replacement = {
@@ -298,8 +330,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": not self.shard_config.parallel_output},
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": False,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
)
],
)
@@ -308,7 +343,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
- module_policy.update(addon_module)
+ else:
+ addon_module = {
+ GPT2LMHeadModel: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ )
+ ]
+ )
+ }
+ module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
@@ -353,13 +400,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True},
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
)
]
)
}
- module_policy.update(addon_module)
+ else:
+ addon_module = {
+ GPT2DoubleHeadsModel: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ )
+ ]
+ )
+ }
+ module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index eab4c214a..3315eb1e9 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -29,35 +29,39 @@ class GPTJPolicy(Policy):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
+ ATTN_IMPLEMENTATION = {
+ "eager": GPTJAttention,
+ }
+
policy = {}
+
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
- use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[GPTJModel] = ModulePolicyDescription(
sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="wte",
- target_module=col_nn.VocabParallelEmbedding1D,
- ),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
@@ -76,7 +80,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@@ -84,7 +87,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@@ -92,24 +94,20 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
- kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
- kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@@ -126,6 +124,17 @@ class GPTJPolicy(Policy):
],
)
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="wte",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=GPTJModel,
+ )
+
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
@@ -154,7 +163,7 @@ class GPTJPolicy(Policy):
"forward": get_gptj_flash_attention_forward(),
},
policy=policy,
- target_key=GPTJAttention,
+ target_key=attn_cls,
)
if not self.shard_config.pipeline_stage_manager:
self.append_or_create_method_replacement(
@@ -255,13 +264,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True},
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
)
]
)
}
- policy.update(addon_module)
+ else:
+ addon_module = {
+ GPTJForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ )
+ ]
+ )
+ }
+ policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index bb4551b2c..6e541f792 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -6,7 +6,16 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module
-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import (
+ FusedRMSNorm,
+ Linear1D_Col,
+ Linear1D_Row,
+ PaddingEmbedding,
+ PaddingLMHead,
+ RMSNorm,
+ VocabParallelEmbedding1D,
+ VocabParallelLMHead1D,
+)
from ..modeling.llama import (
LlamaPipelineForwards,
@@ -26,22 +35,34 @@ class LlamaPolicy(Policy):
pass
def preprocess(self):
- if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
-
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
-
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+ from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaFlashAttention2,
+ LlamaModel,
+ LlamaSdpaAttention,
+ )
+ ATTN_IMPLEMENTATION = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
@@ -85,7 +106,7 @@ class LlamaPolicy(Policy):
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
- target_key=LlamaAttention,
+ target_key=attn_cls,
)
elif sp_mode == "all_to_all":
decoder_attribute_replacement = {
@@ -94,7 +115,7 @@ class LlamaPolicy(Policy):
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
- policy[LlamaAttention] = ModulePolicyDescription(
+ policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
self.append_or_create_method_replacement(
@@ -102,7 +123,7 @@ class LlamaPolicy(Policy):
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
- target_key=LlamaAttention,
+ target_key=attn_cls,
)
self.append_or_create_method_replacement(
description={
@@ -117,6 +138,12 @@ class LlamaPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -167,10 +194,12 @@ class LlamaPolicy(Policy):
],
)
+ if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=LlamaModel,
@@ -211,7 +240,7 @@ class LlamaPolicy(Policy):
"forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
},
policy=policy,
- target_key=LlamaAttention,
+ target_key=attn_cls,
)
if self.pipeline_stage_manager is None:
# replace llama model forward method
@@ -327,8 +356,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs={"gather_output": not self.shard_config.parallel_output},
+ target_module=VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": not self.shard_config.parallel_output,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
)
],
)
@@ -337,7 +369,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
- policy.update(new_item)
+ else:
+ new_item = {
+ LlamaForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ )
+ ],
+ )
+ }
+ policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index c0b8b3375..984b71646 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -1,11 +1,26 @@
import warnings
-from typing import Dict, Union
+from functools import partial
+from typing import Callable, Dict, List, Union
import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import (
+ FusedRMSNorm,
+ Linear1D_Col,
+ Linear1D_Row,
+ PaddingEmbedding,
+ PaddingLMHead,
+ VocabParallelEmbedding1D,
+ VocabParallelLMHead1D,
+)
-from ..modeling.mistral import get_mistral_flash_attention_forward
+from ..modeling.mistral import (
+ MistralForwards,
+ get_mistral_flash_attention_forward,
+ get_mistral_model_forward_for_flash_attn,
+)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]
@@ -16,22 +31,34 @@ class MistralPolicy(Policy):
pass
def preprocess(self):
- if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
-
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
-
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
+ from transformers.models.mistral.modeling_mistral import (
+ MistralAttention,
+ MistralDecoderLayer,
+ MistralFlashAttention2,
+ MistralModel,
+ )
+
+ ATTN_IMPLEMENTATION = {
+ "eager": MistralAttention,
+ "flash_attention_2": MistralFlashAttention2,
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
@@ -39,6 +66,12 @@ class MistralPolicy(Policy):
)
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -80,10 +113,12 @@ class MistralPolicy(Policy):
],
)
+ if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=MistralModel,
@@ -118,27 +153,112 @@ class MistralPolicy(Policy):
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
- "forward": get_mistral_flash_attention_forward(),
+ "forward": get_mistral_flash_attention_forward(self.shard_config),
},
policy=policy,
- target_key=MistralAttention,
+ target_key=attn_cls,
)
+ if self.pipeline_stage_manager is None:
+ # replace mistral model forward method
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_mistral_model_forward_for_flash_attn(self.shard_config),
+ },
+ policy=policy,
+ target_key=MistralModel,
+ )
return policy
def postprocess(self):
return self.model
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager is None:
+ return
+
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "MistralModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ if stage_manager.is_interleave:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {
+ "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+ }
+
+ else:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {
+ "forward": partial(
+ new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ )
+ }
+
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "MistralModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ if stage_manager.is_interleave:
+ assert stage_manager.num_model_chunks is not None
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_indices = stage_manager.get_stage_index(layers_per_stage)
+ if stage_manager.is_first_stage(ignore_chunk=True):
+ held_layers.append(module.embed_tokens)
+ for start_idx, end_idx in stage_indices:
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(module.norm)
+
+ else:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+ return held_layers
+
class MistralModelPolicy(MistralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
- if self.pipeline_stage_manager:
- warnings.warn("Mistral doesn't support pipeline parallelism now.")
+ policy = super().module_policy()
+ from transformers.models.mistral.modeling_mistral import MistralModel
- return super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(
+ model_cls=MistralModel, new_forward=MistralForwards.mistral_model_forward, policy=policy
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in mistral model"""
+ return []
class MistralForCausalLMPolicy(MistralPolicy):
@@ -153,19 +273,63 @@ class MistralForCausalLMPolicy(MistralPolicy):
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+ suffix="lm_head",
+ target_module=VocabParallelLMHead1D,
+ kwargs=dict(
+ gather_output=True,
+ make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+ ),
+ )
+ ]
+ )
+ }
+ else:
+ new_item = {
+ MistralForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=PaddingLMHead,
+ kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
)
]
)
}
- if self.pipeline_stage_manager:
- warnings.warn("Mistral doesn't support pipeline parallelism now.")
+ policy.update(new_item)
- policy.update(new_item)
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MistralForCausalLM, new_forward=MistralForwards.mistral_for_causal_lm_forward, policy=policy
+ )
return policy
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ mistral_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if (
+ id(mistral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ and self.pipeline_stage_manager.num_stages > 1
+ ):
+ # tie weights
+ return [
+ {
+ 0: mistral_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
+
class MistralForSequenceClassificationPolicy(MistralPolicy):
def module_policy(self):
@@ -184,9 +348,26 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
]
)
}
-
- if self.pipeline_stage_manager:
- warnings.warn("Mistral doesn't support pipeline parallelism now.")
-
policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MistralForSequenceClassification,
+ new_forward=MistralForwards.mistral_for_sequence_classification_forward,
+ policy=policy,
+ )
+
return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama for sequence classification model"""
+ return []
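The new `get_held_layers()` implementations all revolve around `stage_manager.distribute_layers()`. As a mental model, a naive even split looks like the sketch below; the real distribution may weight stages differently (for example to compensate for the embedding and final norm), so this is only an approximation:

```python
def distribute_layers_evenly(num_layers: int, num_stages: int) -> list:
    # Return (start, end) index pairs, giving earlier stages the remainder layers.
    base, remainder = divmod(num_layers, num_stages)
    bounds, start = [], 0
    for stage in range(num_stages):
        size = base + (1 if stage < remainder else 0)
        bounds.append((start, start + size))
        start += size
    return bounds

# Mistral-7B has 32 decoder layers; with 4 pipeline stages each stage holds 8 of them,
# plus embed_tokens on the first stage and norm (and lm_head / score) on the last.
print(distribute_layers_evenly(32, 4))  # [(0, 8), (8, 16), (16, 24), (24, 32)]
```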
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 98e584be8..9619b3d41 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -5,7 +5,16 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn
-from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import (
+ FusedLayerNorm,
+ LayerNorm,
+ Linear1D_Col,
+ Linear1D_Row,
+ PaddingEmbedding,
+ PaddingLMHead,
+ VocabParallelEmbedding1D,
+ VocabParallelLMHead1D,
+)
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -29,35 +38,34 @@ __all__ = [
class OPTPolicy(Policy):
def __init__(self) -> None:
super().__init__()
- import transformers
- from packaging.version import Version
-
- # TODO: remove this version check when transformers>=4.36.0
- assert Version(transformers.__version__) <= Version(
- "4.33.0"
- ), "The OPT model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
return self.model
def module_policy(self):
- from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+ from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2
+
+ ATTN_IMPLEMENTATION = {
+ "eager": OPTAttention,
+ "flash_attention_2": OptFlashAttention2,
+ }
policy = {}
+ attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
@@ -68,14 +76,9 @@ class OPTPolicy(Policy):
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
- policy[OPTDecoder] = ModulePolicyDescription(
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
- )
- ]
- )
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[OPTDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
@@ -89,7 +92,7 @@ class OPTPolicy(Policy):
]
)
- policy[OPTAttention] = ModulePolicyDescription(
+ policy[attn_cls] = ModulePolicyDescription(
attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -114,6 +117,17 @@ class OPTPolicy(Policy):
],
)
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=OPTDecoder,
+ )
+
# optimization configuration
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@@ -148,7 +162,7 @@ class OPTPolicy(Policy):
"forward": get_opt_flash_attention_forward(self.shard_config),
},
policy=policy,
- target_key=OPTAttention,
+ target_key=attn_cls,
)
if not self.shard_config.pipeline_stage_manager:
self.append_or_create_method_replacement(
@@ -253,8 +267,20 @@ class OPTForCausalLMPolicy(OPTPolicy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True),
+ target_module=VocabParallelLMHead1D,
+ kwargs=dict(
+ gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+ ),
+ ),
+ policy=policy,
+ target_key=OPTForCausalLM,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=PaddingLMHead,
+ kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=OPTForCausalLM,
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 498e62164..c224d7769 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -1,6 +1,8 @@
+import warnings
+
import colossalai.shardformer.layer as col_nn
-from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
+from ..modeling.sam import forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["SamPolicy", "SamModelPolicy"]
@@ -15,7 +17,6 @@ class SamPolicy(Policy):
def module_policy(self):
from transformers.models.sam.modeling_sam import (
- SamAttention,
SamTwoWayAttentionBlock,
SamTwoWayTransformer,
SamVisionAttention,
@@ -30,6 +31,9 @@ class SamPolicy(Policy):
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={
"attn.num_attention_heads": self.model.config.vision_config.num_attention_heads
@@ -210,20 +214,21 @@ class SamPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
- self.append_or_create_method_replacement(
- description={
- "forward": get_sam_flash_attention_forward(),
- },
- policy=policy,
- target_key=SamAttention,
- )
- self.append_or_create_method_replacement(
- description={
- "forward": get_sam_vision_flash_attention_forward(),
- },
- policy=policy,
- target_key=SamVisionAttention,
- )
+ warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
+ # self.append_or_create_method_replacement(
+ # description={
+ # "forward": get_sam_flash_attention_forward(),
+ # },
+ # policy=policy,
+ # target_key=SamAttention,
+ # )
+ # self.append_or_create_method_replacement(
+ # description={
+ # "forward": get_sam_vision_flash_attention_forward(),
+ # },
+ # policy=policy,
+ # target_key=SamVisionAttention,
+ # )
return policy
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 0c8ec15fa..1298f0af3 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -13,8 +13,11 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
+ PaddingEmbedding,
+ PaddingLMHead,
RMSNorm,
VocabParallelEmbedding1D,
+ VocabParallelLMHead1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@@ -36,16 +39,7 @@ class T5BasePolicy(Policy):
pass
def preprocess(self):
- # reshape the embedding layer
- r"""
- Reshape the Embedding layer to make the embedding dimension divisible by world_size
- """
- if self.shard_config.enable_tensor_parallelism:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@@ -61,6 +55,13 @@ class T5BasePolicy(Policy):
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
@@ -71,16 +72,15 @@ class T5BasePolicy(Policy):
warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[T5Stack] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
- ),
]
)
policy[T5LayerSelfAttention] = ModulePolicyDescription(
@@ -176,6 +176,17 @@ class T5BasePolicy(Policy):
]
)
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=T5Stack,
+ )
+
# optimization configuration
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@@ -370,11 +381,19 @@ class T5ModelPolicy(T5BasePolicy):
policy = super().module_policy()
+ embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
+ if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
- target_module=VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5Model,
@@ -406,17 +425,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
policy = super().module_policy()
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="shared",
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
+ policy=policy,
+ target_key=T5ForConditionalGeneration,
+ )
+
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="shared",
- target_module=VocabParallelEmbedding1D,
- ),
- SubModuleReplacementDescription(
- suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
- ),
- ],
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
+ ),
+ policy=policy,
+ target_key=T5ForConditionalGeneration,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+ ),
policy=policy,
target_key=T5ForConditionalGeneration,
)
@@ -467,11 +513,19 @@ class T5EncoderPolicy(T5BasePolicy):
policy = super().module_policy()
+ embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+
+ if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
- target_module=VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5EncoderModel,
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 905398c4d..069ad0c26 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -11,6 +11,7 @@ from ..modeling.vit import (
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
+ get_jit_fused_vit_intermediate_forward,
get_jit_fused_vit_output_forward,
get_vit_flash_self_attention_forward,
)
@@ -24,10 +25,17 @@ class ViTPolicy(Policy):
pass
def preprocess(self):
+ self.enable_bias_gelu_fused = self.shard_config.enable_jit_fused and self.model.config.hidden_act == "gelu"
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention
+ from transformers.models.vit.modeling_vit import (
+ ViTEmbeddings,
+ ViTIntermediate,
+ ViTLayer,
+ ViTOutput,
+ ViTSelfAttention,
+ )
policy = {}
@@ -36,6 +44,9 @@ class ViTPolicy(Policy):
warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[ViTEmbeddings] = ModulePolicyDescription(
attribute_replacement={},
param_replacement=[],
@@ -83,6 +94,9 @@ class ViTPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "skip_bias_add": self.enable_bias_gelu_fused,
+ },
),
SubModuleReplacementDescription(
suffix="output.dense",
@@ -94,6 +108,14 @@ class ViTPolicy(Policy):
),
],
)
+ if self.enable_bias_gelu_fused:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_jit_fused_vit_intermediate_forward(),
+ },
+ policy=policy,
+ target_key=ViTIntermediate,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -115,6 +137,7 @@ class ViTPolicy(Policy):
policy=policy,
target_key=ViTOutput,
)
+
return policy
def new_model_class(self):
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index c63f6d1cc..441e512bb 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -29,13 +29,6 @@ __all__ = [
class WhisperPolicy(Policy):
def __init__(self) -> None:
super().__init__()
- import transformers
- from packaging.version import Version
-
- # TODO: remove this version check when transformers>=4.36.0
- assert Version(transformers.__version__) <= Version(
- "4.33.0"
- ), "The Whisper model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self):
pass
@@ -45,11 +38,7 @@ class WhisperPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@@ -59,10 +48,19 @@ class WhisperPolicy(Policy):
WhisperDecoderLayer,
WhisperEncoder,
WhisperEncoderLayer,
+ WhisperFlashAttention2,
+ WhisperSdpaAttention,
)
policy = {}
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = col_nn.VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = col_nn.PaddingEmbedding
+
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@@ -80,6 +78,9 @@ class WhisperPolicy(Policy):
warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")
if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.encoder_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
policy[WhisperEncoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
@@ -167,13 +168,17 @@ class WhisperPolicy(Policy):
],
)
- policy[WhisperDecoder] = ModulePolicyDescription(
- sub_module_replacement=[
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=[
SubModuleReplacementDescription(
suffix="embed_tokens",
- target_module=col_nn.VocabParallelEmbedding1D,
+ target_module=embedding_cls,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
- ]
+ ],
+ policy=policy,
+ target_key=WhisperDecoder,
)
# optimization configuration
@@ -242,6 +247,20 @@ class WhisperPolicy(Policy):
policy=policy,
target_key=WhisperAttention,
)
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_whisper_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=WhisperFlashAttention2,
+ )
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_whisper_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=WhisperSdpaAttention,
+ )
if not self.shard_config.pipeline_stage_manager:
self.append_or_create_method_replacement(
description={
@@ -280,8 +299,21 @@ class WhisperPolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="proj_out",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True},
+ target_module=col_nn.VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": True,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ },
+ ),
+ policy=base_policy,
+ target_key=WhisperForConditionalGeneration,
+ )
+ else:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="proj_out",
+ target_module=col_nn.PaddingLMHead,
+ kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=base_policy,
target_key=WhisperForConditionalGeneration,
@@ -526,9 +558,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
- def preprocess(self):
- return self.model
-
def module_policy(self):
from transformers import WhisperForAudioClassification
diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py
index 9c6c2b54e..9167da795 100644
--- a/colossalai/shardformer/shard/grad_ckpt_config.py
+++ b/colossalai/shardformer/shard/grad_ckpt_config.py
@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
"""
+
"""
Args:
gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -46,26 +47,15 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
...
"""
- num_stages: Optional[int] = None
- num_model_chunks: Optional[int] = None
- num_model_layers: Optional[int] = None
num_ckpt_layers_per_stage: Optional[List[int]] = None
def __post_init__(self):
- if self._enable_gradient_checkpointing_ratio:
+ if self._enable_customized_ckpt_layers_per_stage:
+ assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage])
+ elif self._enable_gradient_checkpointing_ratio:
if not (0 <= self.gradient_checkpointing_ratio <= 1):
raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")
- if self._enable_customized_ckpt_layers_per_stage:
- assert (
- self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
- )
- assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
- assert all(
- [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
- )
- self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
-
@property
def _enable_gradient_checkpointing_ratio(self) -> bool:
return self.gradient_checkpointing_ratio is not None
@@ -74,13 +64,16 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
def _enable_customized_ckpt_layers_per_stage(self) -> bool:
return self.num_ckpt_layers_per_stage is not None
- def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+ def get_num_ckpt_layers(
+ self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1
+ ) -> int:
if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
raise RuntimeError("No checkpointed layers information is provided")
if self._enable_customized_ckpt_layers_per_stage:
- assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
- num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+ assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks
+ assert stage <= num_stages and model_chunk_id <= num_model_chunks
+ num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
assert num_ckpt_layers <= num_layers
return num_ckpt_layers
else:
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 2d3f6620f..453e8d23e 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -30,6 +30,7 @@ class ShardConfig:
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
"""
+
tensor_parallel_process_group: Optional[ProcessGroup] = None
sequence_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -42,10 +43,9 @@ class ShardConfig:
sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False
parallel_output: bool = True
+ make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
- # TODO padding vocab
- # make_vocab_size_divisible_by: int = 128
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -110,7 +110,15 @@ class ShardConfig:
Turn on all optimization.
"""
# you can add all the optimization flag here
- self.enable_fused_normalization = True
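+ # Fused normalization relies on apex; detect it here and fall back gracefully if it is missing.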
+ try:
+ from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm # noqa
+
+ apex_avail = True
+ except ImportError:
+ apex_avail = False
+ warnings.warn("You set enable_all_optimization=True, but apex is not installed.")
+
+ self.enable_fused_normalization = apex_avail
self.enable_flash_attention = True
self.enable_jit_fused = True
# This can cause non-in-place param sharding when used without ZeRO.
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index b132f47fd..b3991c4f0 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -26,7 +26,7 @@ class ShardFormer:
import colossalai
import torch
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig()
diff --git a/colossalai/tensor/d_tensor/README.md b/colossalai/tensor/d_tensor/README.md
index 3d862dddb..367db5ccd 100644
--- a/colossalai/tensor/d_tensor/README.md
+++ b/colossalai/tensor/d_tensor/README.md
@@ -69,7 +69,7 @@ import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
# define your device mesh
# assume you have 4 GPUs
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index 667a7b78e..c2cf73181 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from .sharding_spec import ShardingSpec
@@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta):
[3.],
[3.]])
"""
+
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
+
+ target_tensor = tensor
for comm_spec in comm_action_sequence:
- tensor = comm_spec.covert_spec_to_action(tensor)
- tensor.dist_layout = target_layout
- return tensor
+ target_tensor = comm_spec.covert_spec_to_action(target_tensor)
+ target_tensor.dist_layout = target_layout
+
+ # restore the padding information
+ if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+ target_tensor = init_as_padded_tensor(
+ target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+ )
+
+ return target_tensor
diff --git a/colossalai/tensor/padded_tensor/__init__.py b/colossalai/tensor/padded_tensor/__init__.py
new file mode 100644
index 000000000..353ff35f8
--- /dev/null
+++ b/colossalai/tensor/padded_tensor/__init__.py
@@ -0,0 +1,3 @@
+from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+
+__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
diff --git a/colossalai/tensor/padded_tensor/api.py b/colossalai/tensor/padded_tensor/api.py
new file mode 100644
index 000000000..5b66c016b
--- /dev/null
+++ b/colossalai/tensor/padded_tensor/api.py
@@ -0,0 +1,128 @@
+import torch
+
+
+def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
+ """
+ Hijack the detach and clone methods of the tensor so that the padding metadata
+ (_padding_dim, _origin_length, _current_length) is carried over to the returned tensor.
+
+ Args:
+ ptensor (torch.Tensor): The padded tensor to be hijacked.
+
+ Returns:
+ torch.Tensor: The hijacked tensor.
+ """
+ ptensor._unpad_detach = ptensor.detach
+ ptensor._unpad_clone = ptensor.clone
+
+ def new_detach(self):
+ t_ = self._unpad_detach()
+ t_._padding_dim = self._padding_dim
+ t_._origin_length = self._origin_length
+ t_._current_length = self._current_length
+ return t_
+
+ def new_clone(self, *args, **kwargs):
+ t_ = self._unpad_clone(*args, **kwargs)
+ t_._padding_dim = self._padding_dim
+ t_._origin_length = self._origin_length
+ t_._current_length = self._current_length
+ return t_
+
+ # bind the new methods to the tensor
+ ptensor.detach = new_detach.__get__(ptensor)
+ ptensor.clone = new_clone.__get__(ptensor)
+ return ptensor
+
+
+def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
+ """
+ Restore the original detach and clone methods that were replaced by `_hijack_detach_and_clone`.
+
+ Args:
+ ptensor (torch.Tensor): The tensor whose original methods should be restored.
+
+ Returns:
+ torch.Tensor: The tensor with its original detach and clone methods.
+ """
+ ptensor.detach = ptensor._unpad_detach
+ ptensor.clone = ptensor._unpad_clone
+
+ delattr(ptensor, "_unpad_detach")
+ delattr(ptensor, "_unpad_clone")
+
+ return ptensor
+
+
+def is_padded_tensor(tensor: torch.Tensor) -> bool:
+ """
+ Check whether the given tensor is a padded tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ bool: Whether the given tensor is a padded tensor.
+ """
+ return hasattr(tensor, "_padding_dim")
+
+
+def to_padded_tensor(
+ tensor: torch.Tensor,
+ current_length: int,
+ padding_dim: int,
+) -> torch.Tensor:
+ assert (
+ padding_dim < tensor.dim()
+ ), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"
+
+ if is_padded_tensor(tensor):
+ return tensor
+
+ origin_length = tensor.shape[padding_dim]
+ padding_num = current_length - origin_length
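+ # Build a zero-filled block matching the tensor everywhere except along padding_dim, then append it in place.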
+ padding_data = torch.zeros(
+ *tensor.shape[:padding_dim],
+ padding_num,
+ *tensor.shape[padding_dim + 1 :],
+ device=tensor.device,
+ dtype=tensor.dtype,
+ )
+ tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
+
+ tensor._padding_dim = padding_dim
+ tensor._origin_length = origin_length
+ tensor._current_length = current_length
+
+ _hijack_detach_and_clone(tensor)
+
+ return tensor
+
+
+def to_unpadded_tensor(ptensor: torch.Tensor):
+ if not is_padded_tensor(ptensor):
+ return ptensor
+
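+ # Slice away the padded region along padding_dim, then drop the padding metadata and restore the original methods.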
+ unpad_slices = [slice(None)] * ptensor.dim()
+ unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
+ ptensor.data = ptensor.data[tuple(unpad_slices)]
+
+ delattr(ptensor, "_padding_dim")
+ delattr(ptensor, "_origin_length")
+ delattr(ptensor, "_current_length")
+
+ _hijack_back_detach_and_clone(ptensor)
+
+ return ptensor
+
+
+def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+ if is_padded_tensor(tensor):
+ return tensor
+
+ tensor._padding_dim = padding_dim
+ tensor._origin_length = origin_length
+ tensor._current_length = current_length
+
+ _hijack_detach_and_clone(tensor)
+
+ return tensor
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index e415b5fc3..bdf7b19f3 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
- dtype: {a.dtype} vs {b.dtype}",
+ dtype: {a.dtype} vs {b.dtype}",
)
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index bc6c9d088..b25de1d68 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
)
+from colossalai.tensor.padded_tensor import (
+ init_as_padded_tensor,
+ is_padded_tensor,
+ to_padded_tensor,
+ to_unpadded_tensor,
+)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
@@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper):
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
)
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
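+ # Strip any padding from the gathered parameter so the saved checkpoint keeps the original (unpadded) shape.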
+ if is_padded_tensor(tensor):
+ record_tensor = init_as_padded_tensor(
+ record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+ )
+ record_tensor = to_unpadded_tensor(record_tensor)
assert tensor not in chunk_to_save_data
chunk_to_save_data[tensor] = record_tensor
@@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
+ if is_padded_tensor(p_mapping[param]):
+ p_mapping[param] = to_unpadded_tensor(p_mapping[param])
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
@@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper):
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
+
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper):
if state_key in state_dict:
input_param = state_dict[state_key]
+ global_shape = dest_tensor.shape
+ if source_device_mesh is not None and source_sharding_spec is not None:
+ global_shape = get_global_shape(dest_tensor)
+
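+ # Re-pad the loaded tensor so its shape matches the padded destination parameter before it is distributed.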
+ if is_padded_tensor(dest_tensor):
+ padding_dim = dest_tensor._padding_dim
+ input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
+
if source_device_mesh is not None and source_sharding_spec is not None:
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
elif shard_fn is not None and gather_fn is not None:
@@ -818,6 +840,7 @@ class GeminiDDP(ModelWrapper):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
+ for buffer in self.module.buffers():
buffer.data = buffer.to(get_accelerator().get_current_device())
if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision)
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 18367af59..ae02fe297 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
+ get_global_shape,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
)
+from colossalai.tensor.padded_tensor import (
+ init_as_padded_tensor,
+ is_padded_tensor,
+ to_padded_tensor,
+ to_unpadded_tensor,
+)
from colossalai.utils import disposable, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper):
max_norm: float = 0.0,
norm_type: float = 2.0,
tp_group: ProcessGroup = None,
- optimizer_params_info=None,
+ params_info=None,
verbose: bool = False,
**defaults: Any,
):
@@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
self.tp_group = tp_group
- self.optimizer_params_info = optimizer_params_info
+ self.params_info = params_info
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
self.verbose = verbose
@@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper):
is_customized_distributed = is_customized_distributed_tensor(param)
shard_spec = get_sharding_spec(param) if is_dtensor else None
device_mesh = get_device_mesh(param) if is_dtensor else None
- global_shape = self.optimizer_params_info["id2shape"][param_id]
+ global_shape = self.params_info["id2shape"][param_id]
# If the chunk is kept gathered,
# the parameters are treated the same as that of those in strict DDP during training.
@@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper):
else:
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor:
+ global_shape = get_global_shape(param)
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(
state_tensor,
@@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
-
- collected_states[state_name] = state_tensor.reshape(global_shape)
+ state_tensor = state_tensor.reshape(global_shape)
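+ # Remove the padding from the collected optimizer state so it is saved with the original shape.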
+ if is_padded_tensor(param):
+ state_tensor = init_as_padded_tensor(
+ state_tensor, param._current_length, param._origin_length, param._padding_dim
+ )
+ state_tensor = to_unpadded_tensor(state_tensor)
+ collected_states[state_name] = state_tensor
return collected_states
# Check whether the param with given id is managed by current process.
@@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper):
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor:
+ global_shape = get_global_shape(param)
state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
@@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
+ if is_padded_tensor(param):
+ state_tensor = init_as_padded_tensor(
+ state_tensor, param._current_length, param._origin_length, param._padding_dim
+ )
+ state_tensor = to_unpadded_tensor(state_tensor)
return collected_states
@@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper):
Load saved optimizer states into parameter with given id.
"""
- def cast(param, state_range, value, key=None):
+ def cast(param, state_range, value, global_shape, origin_shape, key=None):
"""
Make a copy of the needed segment of value and cast it to device of param.
"""
@@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper):
)
if is_dtensor:
- value = torch.reshape(value, global_shape)
+ global_shape = get_global_shape(real_param)
+
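+ # Reshape the loaded state to its original (unpadded) shape, then pad it along the parameter's padding dim.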
+ if is_padded_tensor(real_param):
+ value = torch.reshape(value, origin_shape)
+ padding_dim = real_param._padding_dim
+ value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
+
+ if is_dtensor:
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
elif is_customized_distributed:
value = torch.reshape(value, global_shape)
@@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper):
is_customized_distributed = is_customized_distributed_tensor(real_param)
shard_spec = get_sharding_spec(real_param) if is_dtensor else None
device_mesh = get_device_mesh(real_param) if is_dtensor else None
- global_shape = self.optimizer_params_info["id2shape"][param_id]
+ global_shape = self.params_info["id2shape"][param_id]
+ origin_shape = global_shape
for k, v in saved_states.items():
- updated_states[k] = cast(fake_param, state_range, v, k)
+ updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index f395fc60e..2ebc704f7 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -11,7 +11,9 @@ from .base_store import BaseStore
class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
+ self.reset_all()
+ def reset_all(self) -> None:
# init
self.current_group_id = 0
self._num_elements_in_bucket = 0
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index 73a1db5a0..6d4fcbb86 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -82,6 +82,9 @@ class GradientStore(BaseStore):
"""
grad_list = []
+ # When using LoRA with multiple param_groups, some param_groups may contain no parameters with gradients.
+ if group_id not in self._grads_of_params.keys():
+ return grad_list
for param_grads in self._grads_of_params[group_id].values():
grad_list.append(param_grads[self._working_index])
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index bbbaf13b5..345dfde73 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -40,7 +40,13 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
max_scale: float = 2**32,
) -> None:
super().__init__(
- initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
@@ -229,9 +235,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
- assert (
- param.dtype == self._dtype
- ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+ if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False:
+ assert (
+ param.dtype == self._dtype
+ ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
@@ -273,11 +280,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Backward Reduction Hook #
###########################
- def _grad_handler(self, param, group_id, grad):
+ def _grad_handler(self, group_id, param):
# if run with no_sync context, would not sync grad when backward
if self.require_grad_sync:
self._add_to_bucket(param, group_id)
- return grad
def _attach_reduction_hook(self):
# we iterate over the working params
@@ -286,7 +292,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
- param.register_hook(partial(self._grad_handler, param, group_id))
+ param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))
#######################
# Reduction Functions #
@@ -415,7 +421,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
self._update_partitoned_grad(
- non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+ non_moe_grad_in_bucket_current_rank,
+ recieved_grad,
+ group_id,
+ 1,
)
if len(moe_grad_list) > 0:
@@ -423,7 +432,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
+ dist.reduce_scatter(
+ recieved_grad,
+ flat_grads_list,
+ group=self.moe_extra_dp_pg,
+ )
param_slice = self._world_size // self.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad:
@@ -444,14 +457,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._add_grad(grad, self._world_size, group_id, param_id, rank)
def _update_partitoned_grad(
- self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+ self,
+ origin_grad_list: List,
+ flat_grad: torch.Tensor,
+ group_id: int,
+ partition_num: int,
) -> None:
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id)
- def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+ def _add_grad(
+ self,
+ grad: torch.Tensor,
+ partition_num: int,
+ group_id: int,
+ param_id: int,
+ rank: int = 0,
+ ) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
@@ -534,6 +558,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if param.grad is not None:
param.grad.detach()
param.grad.zero_()
+ self._bucket_store.reset_all()
####################
# Update Parameter #
@@ -655,14 +680,20 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(
- all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+ all_splited_param,
+ splited_param.to(device).to(self._dtype),
+ group=self.moe_extra_dp_pg,
)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size)
]
- dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+ dist.all_gather(
+ all_splited_param,
+ splited_param.to(device).to(self._dtype),
+ group=self.dp_pg,
+ )
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -685,7 +716,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.tensor(
- [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+ [float(total_norm)],
+ device=get_accelerator().get_current_device(),
+ dtype=torch.float,
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@@ -698,10 +731,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.tensor(
- [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
+ [float(total_norm_exponentiated)],
+ device=get_accelerator().get_current_device(),
+ dtype=torch.float,
)
torch.distributed.all_reduce(
- total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+ total_norm_exponentiated_cuda,
+ op=torch.distributed.ReduceOp.SUM,
+ group=self.dp_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
@@ -920,5 +957,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
if hasattr(self, "moe_master_to_working_map"):
- return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+ return {
+ **self._param_store.master_to_working_param,
+ **self.moe_master_to_working_map,
+ }
return self._param_store.master_to_working_param
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 6d243a808..2e5437752 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,8 @@
## 新闻
+* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
+* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
@@ -51,7 +53,7 @@
并行训练样例展示
- - LLaMA 1/2
+ - LLaMA 1/2/3
- MoE
- GPT-3
- GPT-2
@@ -126,7 +128,7 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
[Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节
[[代码]](https://github.com/hpcaitech/Open-Sora)
-[[博客]](https://hpc-ai.com/blog/open-sora-v1.0)
+[[博客]](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
[[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora)
[[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)
@@ -261,6 +263,14 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的
(返回顶端)
## 并行训练样例展示
+### LLaMA3
+
+
+
+
+- 700亿参数LLaMA3训练加速18%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama)
+
### LLaMA2
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 0133dfd86..b27f9c811 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -75,7 +75,7 @@ WARMUP_FRACTION = 0.1
we create a distributed environment.
```python
# Launch ColossalAI
-colossalai.launch_from_torch(config={}, seed=42)
+colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
```
prepare the dataset. You can use `plugin.prepare_dataloader` to generate a dataloader or customize your own dataloader.
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index dfc2cd596..ac4169344 100644
--- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -71,7 +71,7 @@ PP_SIZE = 2
Create a distributed environment.
```python
# Launch ColossalAI
-colossalai.launch_from_torch(config={}, seed=SEEDå)
+colossalai.launch_from_torch(seed=SEED)
coordinator = DistCoordinator()
world_size = coordinator.world_size
```
diff --git a/docs/source/en/basics/booster_api.md b/docs/source/en/basics/booster_api.md
index 2c75dd9ac..a33be3b49 100644
--- a/docs/source/en/basics/booster_api.md
+++ b/docs/source/en/basics/booster_api.md
@@ -55,7 +55,7 @@ from colossalai.booster.plugin import TorchDDPPlugin
def train():
# launch colossalai
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')
# create plugin and objects for training
plugin = TorchDDPPlugin()
diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md
index 334757ea7..8a6028d6c 100644
--- a/docs/source/en/basics/launch_colossalai.md
+++ b/docs/source/en/basics/launch_colossalai.md
@@ -87,8 +87,7 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
-colossalai.launch(config=args.config,
- rank=args.rank,
+colossalai.launch(rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
@@ -106,20 +105,11 @@ First, we need to set the launch method in our code. As this is a wrapper of the
use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch
launcher and can be read from the environment variable directly.
-config.py
-```python
-BATCH_SIZE = 512
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 2
-```
train.py
```python
import colossalai
-colossalai.launch_from_torch(
- config="./config.py",
-)
+colossalai.launch_from_torch()
...
```
@@ -203,7 +193,6 @@ Do this in your training script:
import colossalai
colossalai.launch_from_slurm(
- config=,
host=args.host,
port=args.port
)
@@ -224,7 +213,6 @@ use them to start the distributed backend.
Do this in your train.py:
```python
colossalai.launch_from_openmpi(
- config=,
host=args.host,
port=args.port
)
@@ -238,3 +226,5 @@ mpirun --hostfile -np python train.py --host
diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md
index ea97dd92e..f1e47e9bb 100644
--- a/docs/source/en/features/gradient_accumulation_with_booster.md
+++ b/docs/source/en/features/gradient_accumulation_with_booster.md
@@ -45,7 +45,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
```
### Step 3. Create training components
diff --git a/docs/source/en/features/gradient_clipping_with_booster.md b/docs/source/en/features/gradient_clipping_with_booster.md
index 14eee67bc..9f9074e1d 100644
--- a/docs/source/en/features/gradient_clipping_with_booster.md
+++ b/docs/source/en/features/gradient_clipping_with_booster.md
@@ -61,7 +61,7 @@ We then need to initialize distributed environment. For demo purpose, we uses `l
for other initialization methods.
```python
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
logger = get_dist_logger()
```
diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md
index 160f68767..30b33b52f 100644
--- a/docs/source/en/features/lazy_init.md
+++ b/docs/source/en/features/lazy_init.md
@@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
-colossalai.launch({})
+colossalai.launch()
plugin = GeminiPlugin()
booster = Booster(plugin)
diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md
index 8e702a578..baaaacddd 100644
--- a/docs/source/en/features/mixed_precision_training_with_booster.md
+++ b/docs/source/en/features/mixed_precision_training_with_booster.md
@@ -20,10 +20,10 @@ In Colossal-AI, we have incorporated different implementations of mixed precisio
3. naive amp
| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent |
-| -------------- | ----------------------- | ------------------------- | ---------------------------------------------------------------------------------------------------- |
-| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation |
-| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 |
-| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 |
+|----------------|-------------------------|---------------------------|------------------------------------------------------------------------------------------------------|
+| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation |
+| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 |
+| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 |
The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex.
The last method is similar to Apex O2 level.
@@ -164,7 +164,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
```
diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md
index 6ed6f2dee..343a1f67e 100644
--- a/docs/source/en/features/nvme_offload.md
+++ b/docs/source/en/features/nvme_offload.md
@@ -185,7 +185,7 @@ Then we can train GPT model with Gemini. The placement policy of Gemini should b
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
config = GPT2Config()
with ColoInitContext(device=torch.cuda.current_device()):
model = GPT2LMHeadModel(config)
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index 672945ea2..68d310f5c 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -310,13 +310,6 @@ if dist.get_world_size() > 1:
2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer.
-3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
- ```python
- from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
- ```
- when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
-
## How Shardformer Works
### Main Idea
diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md
index 62be86488..f0c13830a 100644
--- a/docs/source/en/features/zero_with_chunk.md
+++ b/docs/source/en/features/zero_with_chunk.md
@@ -174,7 +174,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
# build criterion
criterion = GPTLMLoss()
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index cf7d19172..4d4ea8163 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -62,7 +62,7 @@ plugin = HybridParallelPlugin(
## 创建分布式环境.
```python
# Launch ColossalAI
-colossalai.launch_from_torch(config={}, seed=42)
+colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
```
## 定义GPT-2模型的训练组件
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index f32f6c367..c234a3c6e 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -70,7 +70,7 @@ PP_SIZE = 2
首先我们创建一个分布式环境
```python
# Launch ColossalAI
-colossalai.launch_from_torch(config={}, seed=SEEDå)
+colossalai.launch_from_torch(seed=SEED)
coordinator = DistCoordinator()
world_size = coordinator.world_size
```
diff --git a/docs/source/zh-Hans/basics/booster_api.md b/docs/source/zh-Hans/basics/booster_api.md
index bb100964d..a9357617d 100644
--- a/docs/source/zh-Hans/basics/booster_api.md
+++ b/docs/source/zh-Hans/basics/booster_api.md
@@ -60,7 +60,7 @@ from colossalai.booster.plugin import TorchDDPPlugin
def train():
# launch colossalai
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')
# create plugin and objects for training
plugin = TorchDDPPlugin()
diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md
index 39b09deae..a80d16717 100644
--- a/docs/source/zh-Hans/basics/launch_colossalai.md
+++ b/docs/source/zh-Hans/basics/launch_colossalai.md
@@ -74,8 +74,7 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
-colossalai.launch(config=args.config,
- rank=args.rank,
+colossalai.launch(rank=args.rank,
world_size=args.world_size,
host=args.host,
port=args.port,
@@ -93,20 +92,11 @@ PyTorch自带的启动器需要在每个节点上都启动命令才能启动多
首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。
分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。
-config.py
-```python
-BATCH_SIZE = 512
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 2
-```
train.py
```python
import colossalai
-colossalai.launch_from_torch(
- config="./config.py",
-)
+colossalai.launch_from_torch()
...
```
@@ -186,7 +176,6 @@ colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --e
import colossalai
colossalai.launch_from_slurm(
- config=,
host=args.host,
port=args.port
)
@@ -206,7 +195,6 @@ srun python train.py --host --port 29500
您可以在您的训练脚本中尝试以下操作。
```python
colossalai.launch_from_openmpi(
- config=,
host=args.host,
port=args.port
)
@@ -219,3 +207,5 @@ mpirun --hostfile -np python train.py --host
diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
index 824308f94..7ad8fb145 100644
--- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
+++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
@@ -46,7 +46,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
```
diff --git a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md
index fdec09bf1..b000d4585 100644
--- a/docs/source/zh-Hans/features/gradient_clipping_with_booster.md
+++ b/docs/source/zh-Hans/features/gradient_clipping_with_booster.md
@@ -61,7 +61,7 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
我们需要初始化分布式环境. 为了快速演示,我们使用`launch_from_torch`. 您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md)
```python
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
logger = get_dist_logger()
```
diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md
index 137719c69..c9cc0e4ba 100644
--- a/docs/source/zh-Hans/features/lazy_init.md
+++ b/docs/source/zh-Hans/features/lazy_init.md
@@ -29,7 +29,7 @@ from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
-colossalai.launch({})
+colossalai.launch()
plugin = GeminiPlugin()
booster = Booster(plugin)
diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
index 8e9f614a2..53d9013db 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
@@ -19,11 +19,11 @@ AMP 代表自动混合精度训练。
2. apex.amp
3. naive amp
-| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 |
-| -------------- | ------------ | ------------ | --------------------------------------------------------- |
-| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 |
-| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 |
-| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 |
+| Colossal-AI | 支持张量并行 | 支持流水并行 | fp16 范围 |
+|----------------|--------------|--------------|-------------------------------------------------------|
+| AMP_TYPE.TORCH | ✅ | ❌ | 在前向和反向传播期间,模型参数、激活和梯度向下转换至 fp16 |
+| AMP_TYPE.APEX | ❌ | ❌ | 更细粒度,我们可以选择 opt_level O0, O1, O2, O3 |
+| AMP_TYPE.NAIVE | ✅ | ✅ | 模型参数、前向和反向操作,全都向下转换至 fp16 |
前两个依赖于 PyTorch (1.6 及以上) 和 NVIDIA Apex 的原始实现。最后一种方法类似 Apex O2。在这些方法中,Apex-AMP 与张量并行不兼容。这是因为张量是以张量并行的方式在设备之间拆分的,因此,需要在不同的进程之间进行通信,以检查整个模型权重中是否出现 inf 或 nan。我们修改了 torch amp 实现,使其现在与张量并行兼容。
@@ -153,7 +153,7 @@ parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
-colossalai.launch_from_torch(config=dict())
+colossalai.launch_from_torch()
```
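The comparison table in the hunk above (from the zh-Hans guide) lists the three legacy AMP backends; in the booster-era API this guide now documents, the backend is selected through the `Booster` constructor instead of a config file. A minimal sketch, under the assumption that `mixed_precision="fp16"` corresponds to torch AMP (the `AMP_TYPE.TORCH` row); the model and optimizer are placeholders:

```python
# Sketch only: selecting torch AMP via the booster (assumes mixed_precision="fp16" maps to torch AMP).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(seed=42)
model = torch.nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin(), mixed_precision="fp16")
model, optimizer, *_ = booster.boost(model, optimizer)
```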
diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md
index 1feb9dde5..f013e755d 100644
--- a/docs/source/zh-Hans/features/nvme_offload.md
+++ b/docs/source/zh-Hans/features/nvme_offload.md
@@ -175,7 +175,7 @@ Mem usage: 4968.016 MB
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
config = GPT2Config()
with ColoInitContext(device=torch.cuda.current_device()):
model = GPT2LMHeadModel(config)
diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md
index a7bcbd9f2..a42c7cc2e 100644
--- a/docs/source/zh-Hans/features/shardformer.md
+++ b/docs/source/zh-Hans/features/shardformer.md
@@ -303,13 +303,6 @@ if dist.get_world_size() > 1:
2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。
-3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类:
- ```python
- from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
- from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
- ```
- 并且使用这些导入的类初始化模型。
-
## Shardformer的工作原理
diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md
index c4f21c73c..4a4655d60 100644
--- a/docs/source/zh-Hans/features/zero_with_chunk.md
+++ b/docs/source/zh-Hans/features/zero_with_chunk.md
@@ -174,7 +174,7 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
# build criterion
criterion = GPTLMLoss()
diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py
index 40b11d649..48cde8239 100644
--- a/examples/community/roberta/pretraining/run_pretraining.py
+++ b/examples/community/roberta/pretraining/run_pretraining.py
@@ -35,12 +35,12 @@ def main():
if args.vscode_debug:
colossalai.launch(
- config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
+ rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
)
args.local_rank = -1
args.log_interval = 1
else:
- colossalai.launch_from_torch(config={}) # args.colossal_config
+ colossalai.launch_from_torch() # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(
f"launch_from_torch, world size: {torch.distributed.get_world_size()} | "
diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py
index 8ce4dc3bb..64588e904 100644
--- a/examples/images/dreambooth/debug.py
+++ b/examples/images/dreambooth/debug.py
@@ -9,7 +9,7 @@ from colossalai.zero import ColoInitContext
path = "/data/scratch/diffuser/stable-diffusion-v1-4"
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
with ColoInitContext(device="cpu"):
vae = AutoencoderKL.from_pretrained(
path,
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index cc2b2ebc7..2bacb3a04 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -372,9 +372,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
if args.seed is None:
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
else:
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
local_rank = dist.get_rank()
world_size = dist.get_world_size()
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index 227488abe..c4ef2a34e 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -371,9 +371,9 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
def main(args):
if args.seed is None:
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
else:
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
local_rank = gpc.get_local_rank(ParallelMode.DATA)
world_size = gpc.get_world_size(ParallelMode.DATA)
diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py
index 5871bbf87..a53a85180 100644
--- a/examples/images/resnet/train.py
+++ b/examples/images/resnet/train.py
@@ -128,7 +128,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling
diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py
index fdae9ee01..790bb2b74 100644
--- a/examples/images/vit/vit_benchmark.py
+++ b/examples/images/vit/vit_benchmark.py
@@ -46,7 +46,7 @@ def main():
args = parse_benchmark_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size
diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py
index 81009b370..a65f89171 100644
--- a/examples/images/vit/vit_train_demo.py
+++ b/examples/images/vit/vit_train_demo.py
@@ -137,7 +137,7 @@ def main():
args = parse_demo_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index a5b295a40..2d24d87ad 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -231,7 +231,7 @@ def benchmark_inference(args):
def hybrid_inference(rank, world_size, port, args):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
benchmark_inference(args)
diff --git a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
index 498282ba3..18fe76cf0 100644
--- a/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
+++ b/examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
@@ -4,7 +4,7 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from colossalai.inference.modeling.layers.attention import PagedAttention
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
+from tests.test_infer.test_kernels.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
try:
import triton # noqa
diff --git a/examples/inference/benchmark_ops/benchmark_decoding_attn.py b/examples/inference/benchmark_ops/benchmark_decoding_attn.py
index 1a80961a7..4471ddada 100644
--- a/examples/inference/benchmark_ops/benchmark_decoding_attn.py
+++ b/examples/inference/benchmark_ops/benchmark_decoding_attn.py
@@ -2,14 +2,14 @@ import torch
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
-from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
+from tests.test_infer.test_kernels.triton.test_decoding_attn import prepare_data
try:
import triton # noqa
diff --git a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
index 35eae69b6..d90de6664 100644
--- a/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
+++ b/examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py
@@ -3,7 +3,7 @@ import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,
diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py
index 6a499ccf2..80939f5a1 100644
--- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py
+++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py
@@ -2,7 +2,7 @@ import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
mock_alloc_single_token,
diff --git a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py
index 03f797308..0232cb90e 100644
--- a/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py
+++ b/examples/inference/benchmark_ops/benchmark_kv_cache_memcopy.py
@@ -4,8 +4,8 @@ from colossalai.inference.modeling.layers.attention import copy_to_cache
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
-from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
+from tests.test_infer.test_kernels.cuda.test_kv_cache_memcpy import prepare_data as prepare_data_new_kcache_layout
+from tests.test_infer.test_kernels.triton.test_kvcache_copy import prepare_data
try:
import triton # noqa
diff --git a/examples/inference/benchmark_ops/benchmark_xine_copy.py b/examples/inference/benchmark_ops/benchmark_xine_copy.py
index b15232b91..633ceb6f1 100644
--- a/examples/inference/benchmark_ops/benchmark_xine_copy.py
+++ b/examples/inference/benchmark_ops/benchmark_xine_copy.py
@@ -1,7 +1,7 @@
import torch
from colossalai.kernel.triton import get_xine_cache
-from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
+from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
try:
import triton # noqa
diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py
index 10bd367fd..9270c1b0c 100644
--- a/examples/language/bert/benchmark.py
+++ b/examples/language/bert/benchmark.py
@@ -81,7 +81,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={}, seed=42)
+ colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index bd6c393a7..7e8c07fdc 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -202,7 +202,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={}, seed=42)
+ colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
lr = LEARNING_RATE * coordinator.world_size
diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
index b35112498..fbb3a151a 100644
--- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
+++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
@@ -94,8 +94,7 @@ def train_gpt(args):
def run(rank, world_size, port, args):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
train_gpt(args)
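Several example scripts in this patch follow the same spawn-style pattern as the `run(rank, world_size, port, args)` helper above. A minimal sketch of that pattern after the `config` argument removal; the world size and port below are placeholder values:

```python
# Hypothetical spawn-style runner; world size and port are placeholders.
import torch.multiprocessing as mp

import colossalai


def worker(rank: int, world_size: int, port: int):
    # no `config` argument any more; only the process-group parameters remain
    # (use backend="gloo" on CPU-only machines)
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # ... build plugin / model / booster here ...


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size, 29500), nprocs=world_size)
```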
diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
index f3d35dd90..9a33c6598 100644
--- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
+++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
@@ -47,7 +47,7 @@ def get_data(batch_size, seq_len, vocab_size):
def main():
disable_existing_loggers()
- launch_from_torch(config={})
+ launch_from_torch()
logger = get_dist_logger()
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
if FP16:
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 78d090ba2..4911ff124 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -132,7 +132,7 @@ def main():
PROF_FLAG = False # The flag of profiling, False by default
disable_existing_loggers()
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
logger = get_dist_logger()
logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py
index 1315deae6..8c236b524 100644
--- a/examples/language/gpt/hybridparallelism/benchmark.py
+++ b/examples/language/gpt/hybridparallelism/benchmark.py
@@ -67,7 +67,7 @@ def main():
parser.add_argument("--cpu_offload", action="store_true", help="Use gradient checkpointing")
args = parser.parse_args()
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
def empty_init():
diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py
index ef51f938d..e5dc882bc 100644
--- a/examples/language/gpt/hybridparallelism/data.py
+++ b/examples/language/gpt/hybridparallelism/data.py
@@ -62,6 +62,8 @@ class GLUEDataBuilder:
self.text_fields = self.task_text_field_map[task_name]
self.num_labels = self.glue_task_num_labels[task_name]
self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+ if not getattr(self.tokenizer, "pad_token", None):
+ self.tokenizer.pad_token = self.tokenizer._eos_token
self.setup()
def setup(self):
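The `GLUEDataBuilder` hunk above falls back to the EOS token when a checkpoint ships without a pad token (as GPT-2-style tokenizers do). A sketch of the same fallback using the public `eos_token` attribute instead of the private `_eos_token`; `"gpt2"` is only an illustrative checkpoint:

```python
# Sketch of the pad-token fallback (assumes the transformers AutoTokenizer API; "gpt2" is just an example checkpoint).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
if tokenizer.pad_token is None:
    # GPT-2-like tokenizers ship without a pad token; reuse EOS so batched padding works
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["hello world", "hi"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)
```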
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
index 888f47aaa..32b2dfcc0 100644
--- a/examples/language/gpt/hybridparallelism/finetune.py
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -196,7 +196,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={}, seed=42)
+ colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
index 565cf1e01..6b45bd33e 100644
--- a/examples/language/gpt/titans/train_gpt.py
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -36,9 +36,9 @@ def main():
args = parser.parse_args()
disable_existing_loggers()
if args.from_torch:
- colossalai.launch_from_torch(config=args.config)
+ colossalai.launch_from_torch()
else:
- colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
+ colossalai.launch_from_slurm(host=args.host, port=29500, seed=42)
logger = get_dist_logger()
data_path = None if args.use_dummy_dataset else os.environ["DATA"]
diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py
index e10c4929c..f7d7cf864 100644
--- a/examples/language/grok-1/inference_tp.py
+++ b/examples/language/grok-1/inference_tp.py
@@ -16,7 +16,7 @@ if __name__ == "__main__":
parser = get_default_parser()
args = parser.parse_args()
start = time.time()
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
plugin = HybridParallelPlugin(
tp_size=coordinator.world_size,
diff --git a/examples/language/llama/README.md b/examples/language/llama/README.md
new file mode 100644
index 000000000..fa0c6dc07
--- /dev/null
+++ b/examples/language/llama/README.md
@@ -0,0 +1,127 @@
+# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models
+### LLaMA3
+
+
+
+
+- 70 billion parameter LLaMA3 model training accelerated by 18%
+
+### LLaMA2
+
+
+
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
+
+## Usage
+
+> ⚠ This example only provides a benchmarking script. For training or finetuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).
+
+### 1. Installation
+
+Please install the latest ColossalAI from source.
+
+```bash
+BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
+```
+
+Then install other dependencies.
+
+```bash
+pip install -r requirements.txt
+```
+
+### 2. Shell Script Examples
+
+For your convenience, we provide some shell scripts to run benchmarks with various configurations.
+
+You can find them in the `scripts/benchmark_7B` and `scripts/benchmark_70B` directories. The main command follows this format:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+benchmark.py --OTHER_CONFIGURATIONS
+```
+Here we show an example of how to run LLaMA pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
+
+#### a. Running environment
+This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are
+connected via RDMA, and the GPUs within each node are fully connected with NVLink.
+
+#### b. Running command
+
+```bash
+cd scripts/benchmark_7B
+```
+
+First, put your host file (`hosts.txt`) in this directory, listing your real host IPs or host names.
+
+Here is a sample `hosts.txt`:
+```text
+hostname1
+hostname2
+hostname3
+hostname4
+```
+
+Then add environment variables to the script if needed.
+
+Finally, run the following command to start training:
+
+```bash
+bash gemini.sh
+```
+
+If you encounter an out-of-memory (OOM) error during training with the `gemini.sh` script, switching to `gemini_auto.sh` might be a solution, since gemini_auto sets an upper limit on GPU memory usage by offloading part of the model parameters and optimizer states back to CPU memory. The trade-off is that `gemini_auto.sh` is somewhat slower, since more data is transferred between CPU and GPU.
+
+#### c. Results
+If you run the above command successfully, you will get the following results:
+`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
+
+
+## Reference
+```bibtex
+@article{bian2021colossal,
+ title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
+ author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
+ journal={arXiv preprint arXiv:2110.14883},
+ year={2021}
+}
+```
+
+```bibtex
+@software{openlm2023openllama,
+ author = {Geng, Xinyang and Liu, Hao},
+ title = {OpenLLaMA: An Open Reproduction of LLaMA},
+ month = May,
+ year = 2023,
+ url = {https://github.com/openlm-research/open_llama}
+}
+```
+
+```bibtex
+@software{together2023redpajama,
+ author = {Together Computer},
+ title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
+ month = April,
+ year = 2023,
+ url = {https://github.com/togethercomputer/RedPajama-Data}
+}
+```
+
+```bibtex
+@article{touvron2023llama,
+ title={Llama: Open and efficient foundation language models},
+ author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
+ journal={arXiv preprint arXiv:2302.13971},
+ year={2023}
+}
+```
diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama/benchmark.py
similarity index 80%
rename from examples/language/llama2/benchmark.py
rename to examples/language/llama/benchmark.py
index 832465490..5cc602181 100644
--- a/examples/language/llama2/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -3,14 +3,13 @@ import resource
from contextlib import nullcontext
import torch
-from attn import replace_with_flash_attention
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
from colossalai.accelerator import get_accelerator
@@ -19,9 +18,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
-from examples.language.data_utils import RandomDataset
-from examples.language.model_utils import format_numel_str, get_model_numel
-from examples.language.performance_evaluator import PerformanceEvaluator
+from colossalai.shardformer import PipelineGradientCheckpointConfig
# ==============================
# Constants
@@ -78,14 +75,27 @@ def main():
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
+ parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
args = parser.parse_args()
- colossalai.launch_from_torch({})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
def empty_init():
pass
+ # ckpt config for LLaMA3-70B on 64 H100 GPUs
+ hybrid_kwargs = (
+ {
+ "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+ num_ckpt_layers_per_stage=[19, 19, 19, 13],
+ ),
+ "num_layers_per_stage": [19, 20, 20, 21],
+ }
+ if args.custom_ckpt
+ else {}
+ )
+
# ==============================
# Initialize Booster
# ==============================
@@ -98,6 +108,8 @@ def main():
offload_param_frac=args.offload_param_frac,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
+ enable_fused_normalization=torch.cuda.is_available(),
+ enable_flash_attention=args.xformers,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -106,26 +118,34 @@ def main():
warmup_non_model_data_ratio=args.warmup_ratio,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
+ enable_fused_normalization=torch.cuda.is_available(),
+ enable_flash_attention=args.xformers,
)
elif args.plugin == "fsdp":
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
- param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+ param_dtype=torch.float16,
+ reduce_dtype=torch.float16,
+ buffer_dtype=torch.float16,
),
param_init_fn=empty_init(),
)
else:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
- param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+ param_dtype=torch.float16,
+ reduce_dtype=torch.float16,
+ buffer_dtype=torch.float16,
)
)
elif args.plugin == "fsdp_cpu":
if use_empty_init:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
- param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+ param_dtype=torch.float16,
+ reduce_dtype=torch.float16,
+ buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
@@ -133,7 +153,9 @@ def main():
else:
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
- param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+ param_dtype=torch.float16,
+ reduce_dtype=torch.float16,
+ buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
)
@@ -141,12 +163,13 @@ def main():
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
- pp_style="interleaved",
zero_stage=args.zero,
- num_model_chunks=2,
enable_fused_normalization=torch.cuda.is_available(),
+ enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
+ dp_outside=False,
+ **hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
@@ -155,6 +178,7 @@ def main():
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
+ enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
@@ -167,9 +191,12 @@ def main():
# ==============================
# Initialize Dataset and Dataloader
# ==============================
- dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size
+ dp_size = getattr(plugin, "dp_size", coordinator.world_size)
- config = MODEL_CONFIGS[args.config]
+ if args.config in MODEL_CONFIGS:
+ config = MODEL_CONFIGS[args.config]
+ else:
+ config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
@@ -184,14 +211,17 @@ def main():
else nullcontext()
)
+ init_kwargs = {}
+ if config.model_type == "chatglm":
+ init_kwargs["empty_init"] = False
+
with init_ctx:
- model = LlamaForCausalLM(config)
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
-
- if args.xformers:
- replace_with_flash_attention(model)
+ if config.model_type == "chatglm":
+ model.transformer.encoder.gradient_checkpointing = True
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
diff --git a/examples/language/llama/data_utils.py b/examples/language/llama/data_utils.py
new file mode 120000
index 000000000..2da9822df
--- /dev/null
+++ b/examples/language/llama/data_utils.py
@@ -0,0 +1 @@
+../data_utils.py
\ No newline at end of file
diff --git a/examples/language/llama/model_utils.py b/examples/language/llama/model_utils.py
new file mode 120000
index 000000000..73c6818a8
--- /dev/null
+++ b/examples/language/llama/model_utils.py
@@ -0,0 +1 @@
+../model_utils.py
\ No newline at end of file
diff --git a/examples/language/llama/performance_evaluator.py b/examples/language/llama/performance_evaluator.py
new file mode 120000
index 000000000..f4736354b
--- /dev/null
+++ b/examples/language/llama/performance_evaluator.py
@@ -0,0 +1 @@
+../performance_evaluator.py
\ No newline at end of file
diff --git a/examples/language/llama2/requirements.txt b/examples/language/llama/requirements.txt
similarity index 53%
rename from examples/language/llama2/requirements.txt
rename to examples/language/llama/requirements.txt
index 6b475682d..438a4999a 100644
--- a/examples/language/llama2/requirements.txt
+++ b/examples/language/llama/requirements.txt
@@ -1,9 +1,8 @@
-colossalai>=0.3.2
+colossalai>=0.3.6
datasets
numpy
-torch>=1.12.0,<=2.0.0
tqdm
transformers
-flash-attn>=2.0.0,<=2.0.5
+flash-attn>=2.0.0
SentencePiece==0.1.99
tensorboard==2.14.0
diff --git a/examples/language/llama2/scripts/benchmark_70B/3d.sh b/examples/language/llama/scripts/benchmark_70B/3d.sh
similarity index 100%
rename from examples/language/llama2/scripts/benchmark_70B/3d.sh
rename to examples/language/llama/scripts/benchmark_70B/3d.sh
diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini.sh b/examples/language/llama/scripts/benchmark_70B/gemini.sh
similarity index 100%
rename from examples/language/llama2/scripts/benchmark_70B/gemini.sh
rename to examples/language/llama/scripts/benchmark_70B/gemini.sh
diff --git a/examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_70B/gemini_auto.sh
similarity index 100%
rename from examples/language/llama2/scripts/benchmark_70B/gemini_auto.sh
rename to examples/language/llama/scripts/benchmark_70B/gemini_auto.sh
diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini.sh b/examples/language/llama/scripts/benchmark_7B/gemini.sh
similarity index 100%
rename from examples/language/llama2/scripts/benchmark_7B/gemini.sh
rename to examples/language/llama/scripts/benchmark_7B/gemini.sh
diff --git a/examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh b/examples/language/llama/scripts/benchmark_7B/gemini_auto.sh
similarity index 100%
rename from examples/language/llama2/scripts/benchmark_7B/gemini_auto.sh
rename to examples/language/llama/scripts/benchmark_7B/gemini_auto.sh
diff --git a/examples/language/llama2/test_ci.sh b/examples/language/llama/test_ci.sh
similarity index 100%
rename from examples/language/llama2/test_ci.sh
rename to examples/language/llama/test_ci.sh
diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md
deleted file mode 100644
index 068f15cbb..000000000
--- a/examples/language/llama2/README.md
+++ /dev/null
@@ -1,232 +0,0 @@
-# Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
-
-### LLaMA2
-
-
-
-
-- 70 billion parameter LLaMA2 model training accelerated by 195%
-[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
-
-### LLaMA1
-
-
-
-
-- 65-billion-parameter large model pretraining accelerated by 38%
-[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
-
-## Dataset
-
-Different from the original LLaMA, we use [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB to download compressed.
-
-A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).
-
-RedPajama-Data-1T consists of seven data slices:
-
-| | RedPajama | LLaMA |
-|---------------|--------------|---------------|
-| CommonCrawl | 878 billion | 852 billion |
-| C4 | 175 billion | 190 billion |
-| Github | 59 billion | 100 billion |
-| Books | 26 billion | 25 billion |
-| ArXiv | 28 billion | 33 billion |
-| Wikipedia | 24 billion | 25 billion |
-| StackExchange | 20 billion | 27 billion |
-| Total | 1.2 trillion | 1.25 trillion |
-
-## Training
-
-We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $beta1=0.9$ and $beta2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps.
-
-| params | learning rate | batch size |
-|--------|---------------|------------|
-| 6.7B | 3.0e-4 | 4M |
-| 13.0B | 3.0e-4 | 4M |
-| 32.5B | 1.5e-4 | 4M |
-| 65.2B | 1.5e-4 | 4M |
-
-## Usage
-
-### 1. Installation
-
-Please install the latest ColossalAI from source.
-
-```bash
-BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
-```
-
-Then install other dependencies.
-
-```bash
-pip install -r requirements.txt
-```
-
-Additionally, we recommend you to use torch 1.13.1. We've tested our code on torch 1.13.1 and found it's compatible with our code and flash attention.
-
-### 2. Download the dataset
-
-The dataset can be automatically downloaded by using `huggingface/datasets`. You can specify the dataset path by `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.
-
-### 3. Command line arguments
-
-Yon can use colossalai run to launch multi-nodes training:
-```bash
-colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
-pretrain.py --OTHER_CONFIGURATIONS
-```
-
-Here is a sample hostfile:
-
-```text
-hostname1
-hostname2
-hostname3
-hostname4
-```
-
-Make sure master node can access all nodes (including itself) by ssh without password.
-
-Here is details about CLI arguments:
-
-- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2.
-- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
-- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama.
-- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
-- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
-- Learning rate: `--lr`. The default value is 3e-4.
-- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
-- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
-- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
-- Max length: `-l`, `--max_length`. The default value is 4096.
-- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
-- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
-- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
-- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
-- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
-- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
-- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
-
-
-### 4. Shell Script Examples
-
-For your convenience, we provide some shell scripts to run benchmark with various configurations.
-
-You can find them in `scripts/benchmark_7B` and `scripts/benchmark_70B` directory. The main command should be in the format of:
-```bash
-colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
-benchmark.py --OTHER_CONFIGURATIONS
-```
-Here we will show an example of how to run training
-llama pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.
-
-#### a. Running environment
-This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are
-connected with RDMA and GPUs within one node are fully connected with NVLink.
-
-#### b. Running command
-
-```bash
-cd scripts/benchmark_7B
-```
-
-First, put your host file (`hosts.txt`) in this directory with your real host ip or host name.
-
-Here is a sample `hosts.txt`:
-```text
-hostname1
-hostname2
-hostname3
-hostname4
-```
-
-Then add environment variables to script if needed.
-
-Finally, run the following command to start training:
-
-```bash
-bash gemini.sh
-```
-
-If you encounter out-of-memory(OOM) error during training with script `gemini.sh`, changing to script `gemini_auto.sh` might be a solution, since gemini_auto will set a upper limit on GPU memory usage through offloading part of the model parameters and optimizer states back to CPU memory. But there's a trade-off: `gemini_auto.sh` will be a bit slower, since more data are transmitted between CPU and GPU.
-
-#### c. Results
-If you run the above command successfully, you will get the following results:
-`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.
-
-
-## Reference
-```
-@article{bian2021colossal,
- title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
- author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
- journal={arXiv preprint arXiv:2110.14883},
- year={2021}
-}
-```
-
-```bibtex
-@software{openlm2023openllama,
- author = {Geng, Xinyang and Liu, Hao},
- title = {OpenLLaMA: An Open Reproduction of LLaMA},
- month = May,
- year = 2023,
- url = {https://github.com/openlm-research/open_llama}
-}
-```
-
-```bibtex
-@software{together2023redpajama,
- author = {Together Computer},
- title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
- month = April,
- year = 2023,
- url = {https://github.com/togethercomputer/RedPajama-Data}
-}
-```
-
-```bibtex
-@article{touvron2023llama,
- title={Llama: Open and efficient foundation language models},
- author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
- journal={arXiv preprint arXiv:2302.13971},
- year={2023}
-}
-```
-
-
-# Fine-tune Llama2
-
-We also provide a example to fine-tune llama2 in `finetune.py`,
-
-Make sure master node can access all nodes (including itself) by ssh without password.
-
-Here is details about CLI arguments:
-
-- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag.
-- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
-- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
-- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`.
-- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
-- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
-- Learning rate: `--lr`. The default value is 3e-4.
-- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
-- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
-- Max length: `-l`, `--max_length`. The default value is 4096.
-- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
-- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
-- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
-- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
-- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
-- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
-- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
-
-
-```shell
-torchrun --standalone --nproc_per_node 8 finetune.py \
- --plugin "hybrid_parallel" \
- --dataset "yizhongw/self_instruct" \
- --model_path "/path/llama" \
- --task_name "super_natural_instructions" \
- --save_dir "/path/output"
-```
diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py
deleted file mode 120000
index 4e95c7bfa..000000000
--- a/examples/language/llama2/attn.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
\ No newline at end of file
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
deleted file mode 100644
index 69b4ebe42..000000000
--- a/examples/language/llama2/finetune.py
+++ /dev/null
@@ -1,313 +0,0 @@
-import argparse
-import math
-import os
-import resource
-from contextlib import nullcontext
-from functools import partial
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from attn import replace_with_flash_attention
-from data_utils import load_json, prepare_dataloader, save_json
-from datasets import load_dataset
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
-
-import colossalai
-from colossalai.accelerator import get_accelerator
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-
-def get_model_numel(model: nn.Module) -> int:
- return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
- B = 1024**3
- M = 1024**2
- K = 1024
- if numel >= B:
- return f"{numel / B:.2f} B"
- elif numel >= M:
- return f"{numel / M:.2f} M"
- elif numel >= K:
- return f"{numel / K:.2f} K"
- else:
- return f"{numel}"
-
-
-def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
- texts = [sample["prompt"] + sample["completion"] for sample in batch]
- data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
- data = {k: v.cuda() for k, v in data.items()}
- data["labels"] = data["input_ids"].clone()
- return data
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
- tensor = tensor.data
- tensor.div_(dist.get_world_size())
- return tensor
-
-
-def save(
- booster: Booster,
- model: nn.Module,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler,
- epoch: int,
- step: int,
- batch_size: int,
- coordinator: DistCoordinator,
- save_dir: str,
-):
- save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
- os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-
- booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
- booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
- booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
- running_states = {
- "epoch": epoch,
- "step": step,
- "sample_start_index": step * batch_size,
- }
- if coordinator.is_master():
- save_json(running_states, os.path.join(save_dir, "running_states.json"))
-
-
-def load(
- booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
-) -> Tuple[int, int, int]:
- booster.load_model(model, os.path.join(load_dir, "model"))
- booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
- booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
- running_states = load_json(os.path.join(load_dir, "running_states.json"))
- return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
-
-
-def _criterion(outputs, inputs):
- return outputs.loss
-
-
-def main():
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune")
- parser.add_argument(
- "-p",
- "--plugin",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
- default="gemini",
- help="Choose which plugin to use",
- )
- parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path")
- parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run")
- parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
- parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
- parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
- parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
- parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
- parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
- parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
- parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
- parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
- parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
- args = parser.parse_args()
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch({})
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "gemini":
- plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
- )
- elif args.plugin == "hybrid_parallel":
- # modify the param accordingly, default configuration is for llama2-7b
- plugin = HybridParallelPlugin(
- tp_size=4,
- pp_size=2,
- num_microbatches=None,
- microbatch_size=1,
- enable_jit_fused=False,
- zero_stage=0,
- precision="fp32",
- initial_scale=1,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
- is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
- print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
-
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if print_flag:
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
- # ==============================
- # Initialize Model, Optimizer and LR Scheduler
- # ==============================
-
- config = LlamaConfig.from_pretrained(args.model_path)
- # use lazy init when using GeminiPlugin
- init_ctx = (
- LazyInitContext(default_device=get_accelerator().get_current_device())
- if isinstance(plugin, GeminiPlugin)
- else nullcontext()
- )
-
- with init_ctx:
- model = LlamaForCausalLM(config)
-
- # ==============================
- # Initialize Tokenizer, Dataset and Dataloader
- # ==============================
- tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
- tokenizer.pad_token = tokenizer.unk_token
-
- dataset = load_dataset(args.dataset, args.task_name)
- train_ds = dataset["train"]
- dataloader = prepare_dataloader(
- train_ds,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length),
- )
-
- if args.grad_checkpoint:
- model.gradient_checkpointing_enable()
- if args.flash_attention:
- replace_with_flash_attention(model)
-
- model_numel = get_model_numel(model)
- coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
-
- optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
- total_step = args.num_epochs * len(dataloader)
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr
- )
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optimizer, _, dataloader, lr_scheduler = booster.boost(
- model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
- )
- torch.set_default_dtype(torch.float)
-
- booster.load_model(model, args.model_path)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
- )
-
- # load checkpoint if specified
- start_epoch = 0
- start_step = 0
- sampler_start_idx = 0
- if args.load is not None:
- coordinator.print_on_master("Loading checkpoint")
- start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
- coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")
-
- num_steps_per_epoch = len(dataloader)
-
- # if resume training, set the sampler start index to the correct value
- dataloader.sampler.set_start_index(sampler_start_idx)
- for epoch in range(start_epoch, args.num_epochs):
- dataloader.sampler.set_epoch(epoch)
- step_nums = num_steps_per_epoch - start_step
- dataloader_iter = iter(dataloader)
-
- with tqdm(
- range(step_nums),
- desc=f"Epoch {epoch}",
- disable=not print_flag,
- total=num_steps_per_epoch,
- initial=start_step,
- ) as pbar:
- for step in pbar:
- if use_pipeline:
- outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
- loss = outputs["loss"]
- else:
- batch = next(dataloader_iter)
- outputs = model(**batch)
- loss = outputs[0]
- booster.backward(loss, optimizer)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- if not use_pipeline:
- all_reduce_mean(loss)
- if print_flag:
- pbar.set_postfix({"loss": loss.item()})
- writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
-
- if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
- coordinator.print_on_master(f"Saving checkpoint")
- save(
- booster,
- model,
- optimizer,
- lr_scheduler,
- epoch,
- step + 1,
- args.batch_size,
- coordinator,
- args.save_dir,
- )
- coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
- # the continue epochs are not resumed, so we need to reset the sampler start index and start step
- dataloader.sampler.set_start_index(0)
- start_step = 0
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
deleted file mode 100644
index 970cd5290..000000000
--- a/examples/language/llama2/pretrain.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import argparse
-import os
-import resource
-from contextlib import nullcontext
-from functools import partial
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from attn import replace_with_flash_attention
-from data_utils import load_json, prepare_dataloader, save_json
-from datasets import load_dataset
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
-from transformers.models.llama.tokenization_llama import LlamaTokenizer
-
-import colossalai
-from colossalai.accelerator import get_accelerator
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-
-MODEL_CONFIGS = {
- "7b": LlamaConfig(max_position_embeddings=4096),
- "13b": LlamaConfig(
- hidden_size=5120,
- intermediate_size=13824,
- num_hidden_layers=40,
- num_attention_heads=40,
- max_position_embeddings=4096,
- ),
- "70b": LlamaConfig(
- hidden_size=8192,
- intermediate_size=28672,
- num_hidden_layers=80,
- num_attention_heads=64,
- max_position_embeddings=4096,
- num_key_value_heads=8,
- ),
-}
-
-
-def get_model_numel(model: nn.Module) -> int:
- return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
- B = 1024**3
- M = 1024**2
- K = 1024
- if numel >= B:
- return f"{numel / B:.2f} B"
- elif numel >= M:
- return f"{numel / M:.2f} M"
- elif numel >= K:
- return f"{numel / K:.2f} K"
- else:
- return f"{numel}"
-
-
-def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
- texts = [sample["text"] for sample in batch]
- data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
- data = {k: v.cuda() for k, v in data.items()}
- data["labels"] = data["input_ids"].clone()
- return data
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
- tensor = tensor.data
- tensor.div_(dist.get_world_size())
- return tensor
-
-
-def save(
- booster: Booster,
- model: nn.Module,
- optimizer: Optimizer,
- lr_scheduler: _LRScheduler,
- epoch: int,
- step: int,
- batch_size: int,
- coordinator: DistCoordinator,
- save_dir: str,
-):
- save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}")
- os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-
- booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
- booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
- booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
- running_states = {
- "epoch": epoch,
- "step": step,
- "sample_start_index": step * batch_size,
- }
- if coordinator.is_master():
- save_json(running_states, os.path.join(save_dir, "running_states.json"))
-
-
-def load(
- booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
-) -> Tuple[int, int, int]:
- booster.load_model(model, os.path.join(load_dir, "model"))
- booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
- booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
- running_states = load_json(os.path.join(load_dir, "running_states.json"))
- return running_states["epoch"], running_states["step"], running_states["sample_start_index"]
-
-
-def _criterion(outputs, inputs):
- return outputs.loss
-
-
-def main():
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration")
- parser.add_argument(
- "-p",
- "--plugin",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"],
- default="gemini",
- help="Choose which plugin to use",
- )
- parser.add_argument(
- "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path"
- )
- parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs")
- parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size")
- parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
- parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps")
- parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
- parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
- parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval")
- parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory")
- parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint")
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping")
- parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory")
- parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention")
- args = parser.parse_args()
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch({})
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "gemini":
- plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip
- )
- elif args.plugin == "hybrid_parallel":
- # modify the parameters accordingly; the default configuration is for llama2-7b
- plugin = HybridParallelPlugin(
- tp_size=4,
- pp_size=2,
- num_microbatches=None,
- microbatch_size=1,
- enable_jit_fused=False,
- zero_stage=0,
- precision=args.mixed_precision,
- initial_scale=1,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
- is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
- print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
-
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if print_flag:
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
- # ==============================
- # Initialize Tokenizer, Dataset and Dataloader
- # ==============================
- tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
- tokenizer.pad_token = tokenizer.unk_token
-
- dataset = load_dataset(args.dataset)
- train_ds = dataset["train"]
- dataloader = prepare_dataloader(
- train_ds,
- batch_size=args.batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length),
- )
-
- # ==============================
- # Initialize Model, Optimizer and LR Scheduler
- # ==============================
- config = MODEL_CONFIGS[args.config]
- # use lazy init when using GeminiPlugin
- init_ctx = (
- LazyInitContext(default_device=get_accelerator().get_current_device())
- if isinstance(plugin, GeminiPlugin)
- else nullcontext()
- )
-
- with init_ctx:
- model = LlamaForCausalLM(config)
-
- if args.grad_checkpoint:
- model.gradient_checkpointing_enable()
- if args.flash_attention:
- replace_with_flash_attention(model)
-
- model_numel = get_model_numel(model)
- coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
-
- optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr
- )
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optimizer, _, dataloader, lr_scheduler = booster.boost(
- model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
- )
- torch.set_default_dtype(torch.float)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
- )
-
- # load checkpoint if specified
- start_epoch = 0
- start_step = 0
- sampler_start_idx = 0
- if args.load is not None:
- coordinator.print_on_master("Loading checkpoint")
- start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
- coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}")
-
- num_steps_per_epoch = len(dataloader)
-
- # if resuming training, set the sampler start index to the correct value
- dataloader.sampler.set_start_index(sampler_start_idx)
- for epoch in range(start_epoch, args.num_epochs):
- dataloader.sampler.set_epoch(epoch)
- dataloader_iter = iter(dataloader)
-
- with tqdm(
- range(start_step, num_steps_per_epoch),
- desc=f"Epoch {epoch}",
- disable=not print_flag,
- total=num_steps_per_epoch,
- initial=start_step,
- ) as pbar:
- for step in pbar:
- if use_pipeline:
- outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
- loss = outputs["loss"]
- else:
- batch = next(dataloader_iter)
- outputs = model(**batch)
- loss = outputs[0]
- booster.backward(loss, optimizer)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- if not use_pipeline:
- all_reduce_mean(loss)
- if print_flag:
- pbar.set_postfix({"loss": loss.item()})
- writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
-
- if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
- coordinator.print_on_master(f"Saving checkpoint")
- save(
- booster,
- model,
- optimizer,
- lr_scheduler,
- epoch,
- step + 1,
- args.batch_size,
- coordinator,
- args.save_dir,
- )
- coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}")
- # subsequent epochs are not resumed from a checkpoint, so reset the sampler start index and start step
- dataloader.sampler.set_start_index(0)
- start_step = 0
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- main()
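The hunks that follow repeat one mechanical change: the obsolete `config` argument is dropped from ColossalAI's launch helpers, leaving only the distributed bootstrap arguments. A minimal sketch of the updated call sites (argument values are illustrative only):

```python
import colossalai

# use one or the other, depending on how the process group is started

# torchrun entry points read rank/world size from the environment;
# only optional arguments such as the seed remain
colossalai.launch_from_torch(seed=42)

# manually spawned workers (as in the unit tests) still pass the process-group info explicitly
colossalai.launch(rank=0, world_size=1, host="localhost", port=29500, backend="nccl")
```

The legacy namespace follows the same convention, e.g. `colossalai.legacy.launch_from_torch()` in the OPT tutorial below.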
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index a6d5f8bf2..22e0c790b 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -146,7 +146,7 @@ def main():
args = parse_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index 92f4e066a..40f072f13 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -207,7 +207,7 @@ def main():
args = parse_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"
diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py
index d16c9fdf9..c2883d96c 100755
--- a/examples/language/opt/opt_benchmark.py
+++ b/examples/language/opt/opt_benchmark.py
@@ -46,7 +46,7 @@ def main():
args = parse_benchmark_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
index 05336bec4..b5b50305c 100644
--- a/examples/language/opt/opt_train_demo.py
+++ b/examples/language/opt/opt_train_demo.py
@@ -64,7 +64,7 @@ def main():
args = parse_demo_args()
# Launch ColossalAI
- colossalai.launch_from_torch(config={}, seed=args.seed)
+ colossalai.launch_from_torch(seed=args.seed)
coordinator = DistCoordinator()
world_size = coordinator.world_size
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 4fac7b507..76a86600b 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -102,7 +102,7 @@ args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()
logger = get_dist_logger()
diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py
index 29101ce08..b7a3f4320 100644
--- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py
+++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py
@@ -20,7 +20,7 @@ def _benchmark(rank, world_size, port):
only result in minor performance drop. So at last we might be able to find better training batch size for our
model (combine with large batch training optimizer such as LAMB).
"""
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = tm.resnet152()
gm = symbolic_trace(model)
raw_graph = deepcopy(gm.graph)
diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py
index cd03a9179..81ef7ca03 100644
--- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py
+++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py
@@ -17,7 +17,7 @@ def _benchmark(rank, world_size, port, args):
The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
"""
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
if args.model == "resnet50":
model = tm.resnet50()
data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py
index a4733126f..2b388fe36 100644
--- a/examples/tutorial/new_api/cifar_resnet/train.py
+++ b/examples/tutorial/new_api/cifar_resnet/train.py
@@ -128,7 +128,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling
diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py
index ec6c852b5..84245d487 100644
--- a/examples/tutorial/new_api/cifar_vit/train.py
+++ b/examples/tutorial/new_api/cifar_vit/train.py
@@ -148,7 +148,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={})
+ colossalai.launch_from_torch()
coordinator = DistCoordinator()
# update the learning rate with linear scaling
diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py
index e97c9017f..624783a79 100644
--- a/examples/tutorial/new_api/glue_bert/finetune.py
+++ b/examples/tutorial/new_api/glue_bert/finetune.py
@@ -125,7 +125,7 @@ def main():
# ==============================
# Launch Distributed Environment
# ==============================
- colossalai.launch_from_torch(config={}, seed=42)
+ colossalai.launch_from_torch(seed=42)
coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size
diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py
index ae8a0f4a0..cb62f77e1 100644
--- a/examples/tutorial/opt/opt/run_clm.py
+++ b/examples/tutorial/opt/opt/run_clm.py
@@ -289,7 +289,7 @@ class DummyDataloader:
def main():
args = parse_args()
disable_existing_loggers()
- colossalai.legacy.launch_from_torch(config=dict())
+ colossalai.legacy.launch_from_torch()
logger = get_dist_logger()
is_main_process = dist.get_rank() == 0
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index a9d8b2363..bb97a2a3a 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -5,10 +5,9 @@ pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon
torchvision
-transformers==4.33.0
timm
titans
-torchaudio
+torchaudio>=0.13.1
torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
torchrec==0.2.0
contexttimer
@@ -21,4 +20,5 @@ flash_attn
datasets
pydantic
ray
+peft>=0.7.1
#auto-gptq now not support torch1.12
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 7fac7f204..8ab13c0ad 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
-torch>=1.12
+torch>=2.1.0
safetensors
einops
pydantic
@@ -16,4 +16,6 @@ ray
sentencepiece
google
protobuf
-ordered-set
+transformers==4.36.2
+peft>=0.7.1
+bitsandbytes>=0.39.0
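With the requirement bumps above (`torch>=2.1.0`, `transformers==4.36.2`, `peft>=0.7.1`, `bitsandbytes>=0.39.0`, `torchaudio>=0.13.1`), an environment can be sanity-checked with a short script such as the following (purely illustrative; not shipped with the patch):

```python
from importlib.metadata import PackageNotFoundError, version

# version constraints introduced by this patch in requirements.txt / requirements-test.txt
constraints = {
    "torch": ">=2.1.0",
    "transformers": "==4.36.2",
    "peft": ">=0.7.1",
    "bitsandbytes": ">=0.39.0",
    "torchaudio": ">=0.13.1",
}

for pkg, constraint in constraints.items():
    try:
        print(f"{pkg}: installed {version(pkg)}, expected {constraint}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed, expected {constraint}")
```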
diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py
index 0b178d58c..f443553bb 100644
--- a/tests/kit/model_zoo/transformers/chatglm2.py
+++ b/tests/kit/model_zoo/transformers/chatglm2.py
@@ -1,7 +1,6 @@
import torch
-
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from torch.nn import init
+from transformers import AutoConfig, AutoModelForCausalLM
from ..registry import ModelAttribute, model_zoo
@@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
)
loss_fn = lambda x: x["loss"]
-config = ChatGLMConfig(
+config = AutoConfig.from_pretrained(
+ "THUDM/chatglm2-6b",
+ trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
+ ffn_hidden_size=214,
num_attention_heads=8,
kv_channels=16,
rmsnorm=True,
original_rope=True,
use_cache=True,
+ multi_query_attention=False,
torch_dtype=torch.float32,
)
-infer_config = ChatGLMConfig(
+
+infer_config = AutoConfig.from_pretrained(
+ "THUDM/chatglm2-6b",
+ trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=128,
@@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
torch_dtype=torch.float32,
)
-model_zoo.register(
- name="transformers_chatglm",
- model_fn=lambda: ChatGLMModel(config, empty_init=False),
- data_gen_fn=data_gen,
- output_transform_fn=output_transform_fn,
- loss_fn=loss_fn_for_chatglm_model,
- model_attribute=ModelAttribute(has_control_flow=True),
-)
+
+def init_chatglm():
+ model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
+ for m in model.modules():
+ if m.__class__.__name__ == "RMSNorm":
+ init.ones_(m.weight)
+ return model
+
model_zoo.register(
name="transformers_chatglm_for_conditional_generation",
- model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+ model_fn=init_chatglm,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
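With this change the ChatGLM2 test entry is built from the model's remote code on the Hub (`trust_remote_code=True`) instead of the vendored `chatglm2_6b` modules. Other tests consume it through the model-zoo registry; a rough sketch of that flow, with the transform/loss composition simplified:

```python
from tests.kit.model_zoo import model_zoo

# each registry entry is (model_fn, data_gen_fn, output_transform_fn, loss_fn, attrs)
sub_zoo = model_zoo.get_sub_registry("transformers_chatglm_for_conditional_generation")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()    # built via AutoModelForCausalLM.from_config(..., trust_remote_code=True)
    data = data_gen_fn()  # tiny synthetic batch defined in chatglm2.py
    outputs = model(**data)
    # in the real tests, output_transform_fn and loss_fn reduce `outputs` to a scalar loss
```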
diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 58b5b0487..61fa56050 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -64,7 +64,6 @@ if HAS_LLAMA:
intermediate_size=64,
num_attention_heads=4,
max_position_embeddings=128,
- num_labels=16,
)
if hasattr(config, "pad_token_id"):
diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py
index 37f875857..ae5a97002 100644
--- a/tests/kit/model_zoo/transformers/mistral.py
+++ b/tests/kit/model_zoo/transformers/mistral.py
@@ -52,6 +52,9 @@ config = MistralConfig(
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
)
+if hasattr(config, "pad_token_id"):
+ config.pad_token_id = config.eos_token_id
+
model_zoo.register(
name="transformers_mistral",
model_fn=lambda: transformers.MistralModel(config),
diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
index 03bba8e64..14bc7aa57 100644
--- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
@@ -27,7 +27,7 @@ except:
def _run_C_solver_consistency_test(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
index c46f57f75..19d526524 100644
--- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -75,7 +75,7 @@ def check_backward_consistency(
def _run_ckpt_solver(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@@ -111,7 +111,7 @@ def test_ckpt_solver():
def _run_ckpt_solver_torch11(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py
index 373ba28b8..3db7a1925 100644
--- a/tests/test_auto_parallel/test_offload/test_perf.py
+++ b/tests/test_auto_parallel/test_offload/test_perf.py
@@ -141,8 +141,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_fwd_bwd()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
index c41c66745..f39f09d54 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
@@ -42,7 +42,7 @@ class ConvModel(torch.nn.Module):
def check_linear_module(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModel(4, 8).cuda()
input = torch.rand(4, 4).cuda()
output_compare = model(input)
@@ -59,7 +59,7 @@ def check_linear_module(rank, world_size, port):
def check_conv_module(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = ConvModel(3, 6, 2).cuda()
input = torch.rand(4, 3, 64, 64).cuda()
output_compare = model(input)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
index c800f54da..f2b966b10 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
@@ -39,7 +39,7 @@ class GPT2MLPWithCkpt(nn.Module):
def check_act_ckpt(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
torch.rand(1, 64, HIDDEN_SIZE)
input_sample = {
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
index e8f175326..202f3e3bf 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
def check_compatibility_with_ddp(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MLP(4).cuda()
if rank in [0, 1]:
input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
index d57717326..18de92e2a 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -34,7 +34,7 @@ class MLP(torch.nn.Module):
def check_auto_parallel_with_gemini(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MLP(4).half().cuda()
if rank in [0, 1]:
input = torch.arange(0, 16).reshape(4, 4).half().cuda()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
index 24968e670..25c5d4ef1 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
@@ -73,7 +73,7 @@ def _check_module_grad(
def check_attention_layer(rank, model_cls, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py
index ba9e28214..d2f3e3724 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py
@@ -31,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda()
input = torch.rand(32, 1024).cuda()
input.requires_grad = True
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
index 455581545..5495282bc 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
@@ -31,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda()
input = torch.rand(4, 4, 64, 64).cuda()
input.requires_grad = True
@@ -72,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = ConvFunctionModule().cuda()
input = torch.rand(4, 4, 64, 64).cuda()
input.requires_grad = True
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
index 639870c89..4958bad6b 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
@@ -30,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True
@@ -68,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = MyModule().cuda()
input = torch.rand(8, 8, 16, 64).cuda()
input.requires_grad = True
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
index ed809a758..a0b81edab 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
@@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.BatchNorm2d(128)).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py
index bd1deb40c..92d91383e 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py
@@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True
@@ -62,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
port: port for initializing process group
"""
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda()
input = torch.rand(4, 128, 64, 64).cuda()
input.requires_grad = True
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
index 73a15f3ba..a8d2fbdfb 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
@@ -40,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = module(using_kwargs).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -150,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg
def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
index 26f9c4ab1..60eadeff9 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
@@ -40,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
if model_cls == AddmmModel:
model = AddmmModel().cuda()
else:
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
index 86df7237a..e52cf28ab 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
@@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_bn_module_handler(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.BatchNorm2d(16)).cuda()
physical_mesh_id = torch.arange(0, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
index e06625e1c..5982227b6 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
@@ -34,7 +34,7 @@ class LinearModule(torch.nn.Module):
def check_linear_module_handler(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda()
physical_mesh_id = torch.arange(0, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
index 690f0c123..c45e3e014 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
@@ -30,7 +30,7 @@ class LinearModule(torch.nn.Module):
def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModule(16, 32, bias=bias).cuda()
physical_mesh_id = torch.arange(0, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
index 5b2e2ab49..ad0d6d18c 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
@@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
class BinaryElementwiseOpModel(nn.Module):
def __init__(self, op):
@@ -145,7 +145,7 @@ class BEOpModelWithIntConst(nn.Module):
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
index 29df12832..ac54f1230 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
@@ -26,7 +26,7 @@ class BMMTorchFunctionModule(nn.Module):
def check_2d_device_mesh(rank, module, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = module().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -121,7 +121,7 @@ def check_2d_device_mesh(rank, module, world_size, port):
def check_1d_device_mesh(rank, module, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = module().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
index 8a37dd925..407216f46 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
@@ -16,7 +16,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_conv_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -153,7 +153,7 @@ class ConvModel(nn.Module):
def check_conv_function_handler(rank, world_size, port, bias):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = ConvModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py
index 9ac6ba95d..f9a5b40a0 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py
@@ -33,7 +33,7 @@ class EmbeddingModule(nn.Module):
def check_embedding_module_handler(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -150,7 +150,7 @@ class EmbeddingFunction(nn.Module):
def check_embedding_function_handler(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = EmbeddingFunction().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index cf802a228..eb8e8ed3e 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -31,7 +31,7 @@ class GetItemFromTensorModel(nn.Module):
def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = GetItemFromTensorModel(getitem_index=getitem_index)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
index 59a66bc6a..45aae2ea9 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
@@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_ln_module_handler(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.LayerNorm(16)).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
index da88b735f..ddabdb700 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -23,7 +23,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
def check_linear_module_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -171,7 +171,7 @@ class LinearModel(nn.Module):
def check_linear_function_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
index 958dc288f..09ad2ae32 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
@@ -51,7 +51,7 @@ class LinearReshapeModel(nn.Module):
def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
if call_function == torch.permute:
reshape_dims = reshape_dims[0]
elif call_function == torch.transpose:
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
index 1a99c32eb..88f34ff10 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
@@ -29,7 +29,7 @@ class LinearSplitModel(nn.Module):
def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = model_cls(softmax_dim=softmax_dim).cuda()
input = torch.rand(8, 16, 64, 32).to("cuda")
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
index 0318023c8..225a729ef 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
@@ -42,7 +42,7 @@ class LinearSplitModel(nn.Module):
def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
if model_cls.__name__ == "ConvSplitModel":
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
index cbd3e4704..a79cfdf6f 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
@@ -32,7 +32,7 @@ class LinearSumModel(nn.Module):
def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
index 466168c79..de483c997 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
@@ -41,7 +41,7 @@ class LinearViewModel(nn.Module):
def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
model = model_cls(tgt_shape).cuda()
if model_cls.__name__ == "ConvViewModel":
diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
index 3aefb3797..f6d6e8303 100644
--- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py
+++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
@@ -9,7 +9,7 @@ from tests.kit.model_zoo import model_zoo
def run_torch_amp(rank, world_size, port):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
sub_model_zoo = model_zoo.get_sub_registry("timm")
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
 # dlrm_interactionarch has no parameters, so skip it
diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py
index 52cb8c46e..e57cadfd8 100644
--- a/tests/test_booster/test_plugin/test_3d_plugin.py
+++ b/tests/test_booster/test_plugin/test_3d_plugin.py
@@ -265,7 +265,7 @@ def run_grad_acc_test(test_args):
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_3d_plugin(early_stop=early_stop)
run_grad_acc_test()
diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py
index 0ac9d0f6d..a2a4a0c07 100644
--- a/tests/test_booster/test_plugin/test_dp_plugin_base.py
+++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Dict, Iterator, List, Tuple, Union
import torch
import torch.distributed as dist
@@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase):
def no_sync(self, model: nn.Module) -> Iterator[None]:
pass
+ def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+ pass
+
+ def support_lora(self) -> bool:
+ pass
+
def check_dataloader_sharding():
plugin = DPPluginWrapper()
@@ -79,7 +85,7 @@ def check_dataloader_sharding():
def run_dist(rank, world_size, port):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_dataloader_sharding()
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 892144772..b2790c0e7 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -161,7 +161,7 @@ def check_gemini_plugin(
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_gemini_plugin(early_stop=early_stop)
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 861fa0131..4908b2d4f 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -2,6 +2,7 @@ from typing import Optional
import torch
import torch.distributed as dist
+from peft import LoraConfig
from torch.optim import Adam
import colossalai
@@ -22,13 +23,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
@clear_cache_before_run()
-def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
device = get_accelerator().get_current_device()
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = Adam(model.parameters(), lr=1e-3)
+
+ if lora_config is not None:
+ model = booster.enable_lora(model, lora_config=lora_config)
+
criterion = lambda x: x.mean()
data = data_gen_fn()
@@ -48,6 +53,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
except Exception as e:
return repr(e)
+ # raise e
@parameterize("stage", [2])
@@ -91,10 +97,42 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+@parameterize("stage", [2])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
+ passed_models = []
+ failed_info = {} # (model_name, error) pair
+
+ sub_model_zoo = model_zoo.get_sub_registry(model_name)
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ task_type = None
+ if name == "transformers_llama_for_casual_lm":
+ task_type = "CAUSAL_LM"
+ if name == "transformers_llama_for_sequence_classification":
+ task_type = "SEQ_CLS"
+ lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+ err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+ torch.cuda.empty_cache()
+
+ if err is None:
+ passed_models.append(name)
+ else:
+ failed_info[name] = err
+ if early_stop:
+ break
+
+ if dist.get_rank() == 0:
+ print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+ print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+ assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_plugin(early_stop=early_stop)
+ check_low_level_zero_lora(early_stop=early_stop)
@rerun_if_address_is_in_use()
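The new `check_low_level_zero_lora` test exercises the `booster.enable_lora` path together with the low-level ZeRO plugin. A condensed, illustrative sketch of that pattern (the tiny Llama config and the final `boost` call are assumptions added for completeness; the test itself pulls its models from the model zoo):

```python
from peft import LoraConfig
from torch.optim import Adam
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(seed=42)  # run under torchrun

plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)

# a deliberately tiny llama, purely for illustration
config = LlamaConfig(
    hidden_size=128, intermediate_size=256, num_hidden_layers=2,
    num_attention_heads=4, max_position_embeddings=128,
)
model = LlamaForCausalLM(config)
optimizer = Adam(model.parameters(), lr=1e-3)

# inject LoRA adapters via the booster, then boost model and optimizer as usual
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
model, optimizer, *_ = booster.boost(model, optimizer)
```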
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index e785843fb..052782047 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -109,7 +109,7 @@ def check_torch_ddp_no_sync():
def run_dist(rank, world_size, port):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_torch_ddp_plugin()
check_torch_ddp_no_sync()
diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
index f69807046..90e98f325 100644
--- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
@@ -73,7 +73,7 @@ def check_torch_fsdp_plugin():
def run_dist(rank, world_size, port):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_torch_fsdp_plugin()
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index ac6f8caef..ade927e6e 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -173,8 +173,7 @@ def exam_lazy_from_pretrained():
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
exam_state_dict_with_origin()
exam_lazy_from_pretrained()
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index 44a000113..cd313c240 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -163,8 +163,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_torch_load_from_gemini()
exam_gemini_load_from_torch()
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index d8a625b98..1cf94433d 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
optimizer.backward(loss)
optimizer.step()
- for group in optimizer.param_groups:
- group["lr"] = 0.1
+ optimizer.zero_grad()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
@@ -133,8 +132,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index e7f44f97e..119e42e31 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -1,5 +1,9 @@
+from copy import deepcopy
+from typing import Optional
+
import torch
import torch.distributed as dist
+from peft import LoraConfig
from torchvision.models import resnet18
from utils import shared_tempdir
@@ -15,6 +19,7 @@ from colossalai.testing import (
spawn,
)
from colossalai.zero import LowLevelZeroOptimizer
+from tests.kit.model_zoo import model_zoo
# stage 1 and 2 process the optimizer/mode the same way
@@ -69,9 +74,107 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
torch.cuda.empty_cache()
+def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
+ try:
+ plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+ new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+ booster = Booster(plugin=plugin)
+ new_booster = Booster(plugin=new_plugin)
+ model = model_fn()
+ optimizer = HybridAdam(model.parameters(), lr=1e-3)
+ new_model = deepcopy(model)
+ new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
+ model = booster.enable_lora(model, lora_config=lora_config)
+ criterion = lambda x: x.mean()
+ data = data_gen_fn()
+
+ data = {
+ k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+ }
+
+ model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+ output = model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+
+ booster.backward(loss, optimizer)
+ optimizer.step()
+
+ with shared_tempdir() as tempdir:
+ model_ckpt_path = f"{tempdir}/model"
+ optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+ booster.save_lora_as_pretrained(model, model_ckpt_path)
+ booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
+ new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
+ new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+ check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+
+ # check master weight
+ assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+ working_param_id_set = set(id(p) for p in new_model.parameters())
+ for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+ assert p_id in working_param_id_set
+ working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+ padding = new_optimizer._param_store.get_param_padding_size(working_param)
+ padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+ working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+ assert torch.equal(
+ working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
+ )
+
+ new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+ check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+
+ except Exception as e:
+ # return repr(e)
+ raise e
+
+
+@clear_cache_before_run()
+@parameterize("stage", [2])
+@parameterize("shard", [True, False])
+@parameterize("offload", [False, True])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora_checkpointIO(
+ stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
+):
+ passed_models = []
+ failed_info = {} # (model_name, error) pair
+
+ sub_model_zoo = model_zoo.get_sub_registry(model_name)
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ if name != "transformers_llama":
+ continue
+ task_type = None
+ if name == "transformers_llama_for_casual_lm":
+ task_type = "CAUSAL_LM"
+ if name == "transformers_llama_for_sequence_classification":
+ task_type = "SEQ_CLS"
+ lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+ err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+ torch.cuda.empty_cache()
+
+ if err is None:
+ passed_models.append(name)
+ else:
+ failed_info[name] = err
+ if early_stop:
+ break
+
+ if dist.get_rank() == 0:
+ print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+ print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+ assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
def run_dist(rank, world_size, port):
- colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_checkpointIO()
+ check_low_level_zero_lora_checkpointIO()
torch.cuda.empty_cache()
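The new `run_fn` above round-trips a LoRA adapter through `booster.save_lora_as_pretrained` and `enable_lora(pretrained_dir=...)`. A condensed sketch of that round trip, assuming the same Booster and plugin APIs as the hunk; `model_fn` and `ckpt_dir` are placeholders and a distributed context is assumed to be initialized:

```python
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam


def save_and_reload_lora(model_fn, ckpt_dir: str):
    """Persist only the LoRA adapter and restore it into a fresh model."""
    lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

    booster = Booster(plugin=LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5))
    model = booster.enable_lora(model_fn(), lora_config=lora_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)
    # (training steps omitted) save only the adapter weights, not the full model
    booster.save_lora_as_pretrained(model, ckpt_dir)

    # A fresh booster/model pair restores the adapter from the saved directory.
    new_booster = Booster(plugin=LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5))
    new_model = new_booster.enable_lora(model_fn(), pretrained_dir=ckpt_dir, lora_config=lora_config)
    new_model, _, _, _, _ = new_booster.boost(new_model, HybridAdam(new_model.parameters(), lr=1e-3))
    return new_model
```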
diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
index 0353ff115..da0d52d06 100644
--- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
+++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
@@ -68,8 +68,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_from_pretrained()
diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
index eeb04df0f..0b9a1605c 100644
--- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -61,7 +61,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
def run_dist(rank, world_size, port):
- colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_torch_ddp_checkpointIO()
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index 1ea70368e..12b70cc04 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -141,7 +141,7 @@ def check_torch_fsdp_ckpt():
def run_dist(rank, world_size, port):
# init dist env
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_torch_fsdp_ckpt()
diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py
index ab61cdae5..5d140064b 100644
--- a/tests/test_cluster/test_device_mesh_manager.py
+++ b/tests/test_cluster/test_device_mesh_manager.py
@@ -6,7 +6,7 @@ from colossalai.testing import spawn
def check_device_mesh_manager(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
device_mesh_manager = DeviceMeshManager()
# TODO(ver217): this test is strictly relies on hardware, temporary skip it
# device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],)
diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py
index 3d206622d..3071c0f59 100644
--- a/tests/test_cluster/test_process_group_mesh.py
+++ b/tests/test_cluster/test_process_group_mesh.py
@@ -6,57 +6,6 @@ from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import spawn
-def check_process_group_mesh_with_gpc():
- from colossalai.legacy.context import ParallelMode
- from colossalai.legacy.core import global_context as gpc
-
- DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
- pg_mesh = ProcessGroupMesh(1, 2, 2)
-
- # check world size
- assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size(
- TP_DIM
- ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}"
- assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM)
- assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM)
-
- # check locak rank (coordinate)
- assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate(
- TP_DIM
- ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}"
- assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM)
- assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM)
-
- # check ranks in group
- tp_group = pg_mesh.get_group_along_axis(TP_DIM)
- assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group)
- pp_group = pg_mesh.get_group_along_axis(PP_DIM)
- assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group)
- dp_group = pg_mesh.get_group_along_axis(DP_DIM)
- assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group)
-
- # check prev rank
- coord = pg_mesh.coordinate()
- if not gpc.is_first_rank(ParallelMode.TENSOR):
- assert coord[TP_DIM] != 0
- prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :]
- assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape)
- if not gpc.is_first_rank(ParallelMode.PIPELINE):
- assert coord[PP_DIM] != 0
- prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :]
- assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape)
-
- # check next rank
- if not gpc.is_last_rank(ParallelMode.TENSOR):
- assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1
- next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :]
- assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape)
- if not gpc.is_last_rank(ParallelMode.PIPELINE):
- assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1
- next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :]
- assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape)
-
-
def check_process_group_mesh_with_cases():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2
@@ -177,14 +126,11 @@ def check_process_group_mesh_with_cases():
def run_dist(rank, world_size, port):
colossalai.launch(
- config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))),
rank=rank,
world_size=world_size,
port=port,
host="localhost",
)
- # TODO(ver217): this function should be removed when gpc is removed
- # check_process_group_mesh_with_gpc()
check_process_group_mesh_with_cases()
diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py
index f4a88f79c..3d9c6d7ce 100644
--- a/tests/test_device/test_alpha_beta.py
+++ b/tests/test_device/test_alpha_beta.py
@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_alpha_beta(rank, world_size, port, physical_devices):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
profiler = AlphaBetaProfiler(physical_devices)
ab_dict = profiler.profile_ab()
for _, (alpha, beta) in ab_dict.items():
diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py
index af44af5d9..b2d057273 100644
--- a/tests/test_device/test_device_mesh.py
+++ b/tests/test_device/test_device_mesh.py
@@ -75,7 +75,7 @@ def check_2d_device_mesh():
def check_init_from_process_group(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@pytest.mark.dist
diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py
index 34f2aacc1..7633f59b9 100644
--- a/tests/test_device/test_extract_alpha_beta.py
+++ b/tests/test_device/test_extract_alpha_beta.py
@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_extract_alpha_beta(rank, world_size, port, physical_devices):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
profiler = AlphaBetaProfiler(physical_devices)
mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py
index 3b398a917..d93f65698 100644
--- a/tests/test_device/test_init_logical_pg.py
+++ b/tests/test_device/test_init_logical_pg.py
@@ -9,7 +9,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_layer(rank, world_size, port):
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
assert rank == dist.get_rank()
diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py
index d9d4e79c1..a44b8e3d6 100644
--- a/tests/test_device/test_search_logical_device_mesh.py
+++ b/tests/test_device/test_search_logical_device_mesh.py
@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_alpha_beta(rank, world_size, port, physical_devices):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
profiler = AlphaBetaProfiler(physical_devices)
best_logical_mesh = profiler.search_best_logical_mesh()
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index 10fe98155..8a3e2d6ec 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -64,7 +64,7 @@ class MyModule(torch.nn.Module):
def _run_act_ckpt_codegen(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and run forward
model = MyModule()
@@ -127,7 +127,7 @@ def test_act_ckpt_codegen():
def _run_act_ckpt_python_code_torch11(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and run forward
model = MyModule()
diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
index f1e87e5ed..69767db2d 100644
--- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py
@@ -32,7 +32,7 @@ class MyModule(torch.nn.Module):
def _run_act_ckpt_codegen(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and run forward
model = MyModule()
@@ -96,7 +96,7 @@ def test_act_ckpt_codegen():
def _run_act_ckpt_python_code_torch11(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and run forward
model = MyModule()
diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py
index da1e73ec3..9df4a6899 100644
--- a/tests/test_fx/test_codegen/test_offload_codegen.py
+++ b/tests/test_fx/test_codegen/test_offload_codegen.py
@@ -66,7 +66,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
def _run_offload_codegen(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and input
model = MyNet().cuda()
@@ -124,7 +124,7 @@ def test_act_ckpt_codegen():
def _run_offload_codegen_torch11(rank, world_size, port):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currently
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model and input
model = MyNet().cuda()
diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py
index 6d890f59d..6b0e12609 100644
--- a/tests/test_fx/test_parallel_1d.py
+++ b/tests/test_fx/test_parallel_1d.py
@@ -33,7 +33,7 @@ CONFIG = dict(parallel=dict(tensor=dict(mode="1d", size=2)))
def check_layer(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
input_tensor = torch.rand(2, 16).cuda()
model = MLP(16).cuda()
symbolic_traced = symbolic_trace(model)
diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py
index 046ee932d..cc0389af9 100755
--- a/tests/test_infer/test_config_and_struct.py
+++ b/tests/test_infer/test_config_and_struct.py
@@ -80,7 +80,7 @@ def check_config_and_inference():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_config_and_inference()
diff --git a/tests/test_infer/test_cuda_graph.py b/tests/test_infer/test_cuda_graph.py
index a0a55d3ad..4cdc62fbe 100644
--- a/tests/test_infer/test_cuda_graph.py
+++ b/tests/test_infer/test_cuda_graph.py
@@ -80,7 +80,7 @@ def check_output_consistency(batch_size):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_output_consistency(32)
check_output_consistency(64)
check_output_consistency(128)
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 25413a292..a0ddbbc7b 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -157,7 +157,7 @@ def check_spec_dec(num_layers, max_length):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
diff --git a/tests/test_infer/test_ops/__init__.py b/tests/test_infer/test_kernels/__init__.py
similarity index 100%
rename from tests/test_infer/test_ops/__init__.py
rename to tests/test_infer/test_kernels/__init__.py
diff --git a/tests/test_infer/test_ops/cuda/__init__.py b/tests/test_infer/test_kernels/cuda/__init__.py
similarity index 100%
rename from tests/test_infer/test_ops/cuda/__init__.py
rename to tests/test_infer/test_kernels/cuda/__init__.py
diff --git a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
similarity index 98%
rename from tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py
rename to tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
index b3bd503bb..80a5d067b 100644
--- a/tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py
+++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
@@ -7,11 +7,11 @@ import torch
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
inference_ops = InferenceOpsLoader().load()
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v3,
diff --git a/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py
similarity index 95%
rename from tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py
rename to tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py
index c632cfe30..b6ba1a01b 100644
--- a/tests/test_infer/test_ops/cuda/test_get_cos_and_sin.py
+++ b/tests/test_infer/test_kernels/cuda/test_get_cos_and_sin.py
@@ -3,7 +3,7 @@ import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
+from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin
inference_ops = InferenceOpsLoader().load()
diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
similarity index 97%
rename from tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py
rename to tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
index e9c99ddc7..d90f64690 100644
--- a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py
+++ b/tests/test_infer/test_kernels/cuda/test_kv_cache_memcpy.py
@@ -4,7 +4,10 @@ import torch.nn.functional as F
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
+from tests.test_infer.test_kernels.triton.kernel_utils import (
+ generate_caches_and_block_tables_v3,
+ mock_alloc_single_token,
+)
inference_ops = InferenceOpsLoader().load()
diff --git a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py b/tests/test_infer/test_kernels/cuda/test_rms_layernorm.py
similarity index 100%
rename from tests/test_infer/test_ops/cuda/test_rms_layernorm.py
rename to tests/test_infer/test_kernels/cuda/test_rms_layernorm.py
diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
similarity index 96%
rename from tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
rename to tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 501bf65d8..8237384c0 100644
--- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -7,8 +7,8 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
-from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
+from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
+from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb
def numpy_allclose(x, y, rtol, atol):
diff --git a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py b/tests/test_infer/test_kernels/cuda/test_silu_and_mul.py
similarity index 100%
rename from tests/test_infer/test_ops/cuda/test_silu_and_mul.py
rename to tests/test_infer/test_kernels/cuda/test_silu_and_mul.py
diff --git a/tests/test_infer/test_ops/triton/__init__.py b/tests/test_infer/test_kernels/triton/__init__.py
similarity index 100%
rename from tests/test_infer/test_ops/triton/__init__.py
rename to tests/test_infer/test_kernels/triton/__init__.py
diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_kernels/triton/kernel_utils.py
similarity index 100%
rename from tests/test_infer/test_ops/triton/kernel_utils.py
rename to tests/test_infer/test_kernels/triton/kernel_utils.py
diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
similarity index 99%
rename from tests/test_infer/test_ops/triton/test_context_attn_unpad.py
rename to tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
index 76785d530..e34fada97 100644
--- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
@@ -5,7 +5,7 @@ from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import context_attention_unpadded
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
similarity index 97%
rename from tests/test_infer/test_ops/triton/test_decoding_attn.py
rename to tests/test_infer/test_kernels/triton/test_decoding_attn.py
index 616d7868b..24741fecf 100644
--- a/tests/test_infer/test_ops/triton/test_decoding_attn.py
+++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
@@ -6,14 +6,14 @@ from packaging import version
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
torch_attn_ref,
)
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
try:
import triton # noqa
diff --git a/tests/test_infer/test_ops/triton/test_fused_rotary_embedding.py b/tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py
similarity index 100%
rename from tests/test_infer/test_ops/triton/test_fused_rotary_embedding.py
rename to tests/test_infer/test_kernels/triton/test_fused_rotary_embedding.py
diff --git a/tests/test_infer/test_ops/triton/test_kvcache_copy.py b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py
similarity index 99%
rename from tests/test_infer/test_ops/triton/test_kvcache_copy.py
rename to tests/test_infer/test_kernels/triton/test_kvcache_copy.py
index 95126c087..336eb256b 100644
--- a/tests/test_infer/test_ops/triton/test_kvcache_copy.py
+++ b/tests/test_infer/test_kernels/triton/test_kvcache_copy.py
@@ -4,7 +4,7 @@ from packaging import version
from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
mock_alloc_single_token,
diff --git a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py b/tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py
similarity index 100%
rename from tests/test_infer/test_ops/triton/test_rmsnorm_triton.py
rename to tests/test_infer/test_kernels/triton/test_rmsnorm_triton.py
diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
similarity index 98%
rename from tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
rename to tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
index 87eb38135..570093693 100644
--- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
@@ -4,7 +4,7 @@ from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
)
diff --git a/tests/test_infer/test_ops/triton/test_xine_copy.py b/tests/test_infer/test_kernels/triton/test_xine_copy.py
similarity index 100%
rename from tests/test_infer/test_ops/triton/test_xine_copy.py
rename to tests/test_infer/test_kernels/triton/test_xine_copy.py
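The preceding group of hunks is a mechanical package rename from `tests/test_infer/test_ops` to `tests/test_infer/test_kernels`; imports of the shared kernel helpers switch package accordingly, for example:

```python
# Before the rename:
# from tests.test_infer.test_ops.triton.kernel_utils import torch_attn_ref
# After the rename:
from tests.test_infer.test_kernels.triton.kernel_utils import torch_attn_ref
```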
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
index 321047706..bca9a1a84 100755
--- a/tests/test_infer/test_kvcache_manager.py
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -164,7 +164,7 @@ def check_cache_manager(test_config):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_cache_manager()
diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py
index 5d6be5cb1..3d6fc3bdb 100644
--- a/tests/test_infer/test_models/test_baichuan.py
+++ b/tests/test_infer/test_models/test_baichuan.py
@@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
@@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
@@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
-def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
+def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
@@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
- test_tp_engine()
+ check_tp_engine()
if __name__ == "__main__":
diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py
index c7a35ebbe..912fdbf11 100644
--- a/tests/test_infer/test_request_handler.py
+++ b/tests/test_infer/test_request_handler.py
@@ -90,7 +90,7 @@ def check_request_handler():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_running_list()
check_request_handler()
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index d0c4cd0a7..c85860a8d 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -7,21 +7,23 @@ from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
@pytest.mark.parametrize(
"subset",
- [COMMON_MODELS]
- if IS_FAST_TEST
- else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"],
+ (
+ [COMMON_MODELS]
+ if IS_FAST_TEST
+ else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]
+ ),
)
@pytest.mark.parametrize("default_device", ["cpu", "cuda"])
-def test_torchvision_models_lazy_init(subset, default_device):
+def test_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
- ("transformers_vit", "transformers_blip2")
+ ("transformers_vit", "transformers_blip2", "transformers_whisper")
):
continue
check_lazy_init(entry, verbose=True, default_device=default_device)
if __name__ == "__main__":
- test_torchvision_models_lazy_init("transformers", "cpu")
+ test_models_lazy_init("transformers", "cpu")
diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py
index fe16bc4d4..0df6335f5 100644
--- a/tests/test_legacy/test_amp/test_naive_fp16.py
+++ b/tests/test_legacy/test_amp/test_naive_fp16.py
@@ -77,7 +77,7 @@ def run_naive_amp():
def run_dist(rank, world_size, port):
- colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, port=port, host="localhost")
run_naive_amp()
diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py
index 5e2e1ede5..dc47dfc72 100644
--- a/tests/test_legacy/test_amp/test_torch_fp16.py
+++ b/tests/test_legacy/test_amp/test_torch_fp16.py
@@ -76,7 +76,7 @@ def run_torch_amp():
def run_dist(rank, world_size, port):
- colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, port=port, host="localhost")
run_torch_amp()
diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
index bc243631a..bd15e10f3 100644
--- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
+++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py
@@ -16,7 +16,7 @@ torch.manual_seed(123)
def check_layer(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False)
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False)
rank = gpc.get_local_rank(ParallelMode.PIPELINE)
if rank == 0:
diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py
index 079022e93..75955df69 100644
--- a/tests/test_legacy/test_comm/test_comm.py
+++ b/tests/test_legacy/test_comm/test_comm.py
@@ -48,7 +48,7 @@ def check_all_reduce():
def check_layer(rank, world_size, port):
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
assert dist.get_rank() == gpc.get_global_rank()
print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size()))
diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py
index 69c68c715..1d618a65f 100644
--- a/tests/test_legacy/test_comm/test_object_list_p2p.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p.py
@@ -88,7 +88,7 @@ def check_send_recv_forward_backward():
def check_layer(rank, world_size, port):
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_send_recv_forward()
check_send_recv_backward()
check_send_recv_forward_backward()
diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
index eb05ea483..c272f51f4 100644
--- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
+++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py
@@ -104,7 +104,7 @@ def check_small_pipeline():
def check_layer(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
disable_existing_loggers()
# check_send_recv_forward()
diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py
index cebbedd30..9057c2c68 100644
--- a/tests/test_legacy/test_layers/test_1d/test_1d.py
+++ b/tests/test_legacy/test_layers/test_1d/test_1d.py
@@ -17,7 +17,7 @@ CONFIG = dict(
def check_layer(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_linear_col()
check_linear_row()
diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py
index 77a4b281a..5be498f90 100644
--- a/tests/test_legacy/test_layers/test_2d/test_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/test_2d.py
@@ -50,7 +50,7 @@ def check_layer():
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
index 437a8f8a7..029274570 100644
--- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py
@@ -38,7 +38,7 @@ def check_layer():
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py
index 7057e2308..876aa7ba8 100644
--- a/tests/test_legacy/test_layers/test_3d/test_3d.py
+++ b/tests/test_legacy/test_layers/test_3d/test_3d.py
@@ -44,7 +44,7 @@ def check_layer():
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py
index d64ff56b8..c45097232 100644
--- a/tests/test_legacy/test_layers/test_cache_embedding.py
+++ b/tests/test_legacy/test_layers/test_cache_embedding.py
@@ -378,7 +378,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
def run_dist(rank, world_size, port):
- colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_parallel_freq_aware_embed_columnwise(rank, world_size)
run_parallel_freq_aware_embed_tablewise(rank, world_size)
diff --git a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
index 506244447..bfedb779c 100644
--- a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
+++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py
@@ -48,7 +48,7 @@ def check_mem():
def run_dist(rank, world_size, port):
- colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_mem()
run()
diff --git a/tests/test_legacy/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py
index 5217e22cc..eae3e0eb3 100644
--- a/tests/test_legacy/test_tensor/test_parameter.py
+++ b/tests/test_legacy/test_tensor/test_parameter.py
@@ -9,7 +9,7 @@ from colossalai.testing import free_port
@pytest.mark.skip
def test_multiinheritance():
- colossalai.legacy.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl")
+ colossalai.legacy.launch(rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl")
colo_param = ColoParameter(None, requires_grad=True)
assert colo_param.dist_spec.placement.value == "r"
assert isinstance(colo_param, ColoTensor)
diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
index cab111358..ba8504d06 100644
--- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
@@ -86,7 +86,7 @@ def check_comm(size, rank, prev_rank, next_rank, logger):
def run_check(rank, world_size, port):
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
logger = get_dist_logger()
rank = gpc.get_global_rank()
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
index cd7fcfe56..ae7b961ae 100644
--- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
@@ -23,7 +23,7 @@ CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=d
def run_schedule(rank, world_size, port):
- launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# build model
model = resnet18(num_classes=10)
diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
index c07ff132b..e1b2128aa 100644
--- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py
@@ -43,7 +43,7 @@ def check_checkpoint_1d(rank, world_size, port):
)
disable_existing_loggers()
- launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
index 2ec1facf2..12747951b 100644
--- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py
@@ -43,7 +43,7 @@ def check_checkpoint_2d(rank, world_size, port):
)
disable_existing_loggers()
- launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
index a6bf702a8..f7e7b6fad 100644
--- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py
@@ -43,7 +43,7 @@ def check_checkpoint_2p5d(rank, world_size, port):
)
disable_existing_loggers()
- launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
index 12d928312..05666cc93 100644
--- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
+++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py
@@ -43,7 +43,7 @@ def check_checkpoint_3d(rank, world_size, port):
)
disable_existing_loggers()
- launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py
index 4993df4f3..30fc17b8e 100644
--- a/tests/test_legacy/test_utils/test_memory.py
+++ b/tests/test_legacy/test_utils/test_memory.py
@@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
def run_dist(rank, world_size, port):
- colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
_run_colo_set_process_memory_fraction_and_colo_device_memory_capacity()
diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
index 9975cc04f..c5fab49f4 100644
--- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
+++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
@@ -62,7 +62,7 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty
def run_dist(rank, world_size, port):
disable_existing_loggers()
- colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.legacy.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_grad_clip_norm(world_size=world_size)
diff --git a/tests/test_legacy/test_zero/test_commons.py b/tests/test_legacy/test_zero/test_commons.py
index 741f519e1..32b15706d 100644
--- a/tests/test_legacy/test_zero/test_commons.py
+++ b/tests/test_legacy/test_zero/test_commons.py
@@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
def run_tensor_move(rank, world_size, port):
- colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.legacy.launch(rank=0, world_size=world_size, host="localhost", port=port, backend="nccl")
src_t = torch.ones(2, 3).cuda()
tgt_t = torch.zeros(2, 3)
diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py
new file mode 100644
index 000000000..b8daf775d
--- /dev/null
+++ b/tests/test_lora/test_lora.py
@@ -0,0 +1,105 @@
+import copy
+import os
+from itertools import product
+
+import torch
+from peft import LoraConfig
+from torch import distributed as dist
+from torch.optim import AdamW
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_checkpoint_io.utils import shared_tempdir
+
+
+@clear_cache_before_run()
+def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
+ model = model_fn()
+ lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+
+ test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+ test_configs = [
+ {
+ "lora_config": lora_config,
+ "quantize": False,
+ },
+ {
+ "lora_config": lora_config,
+ "quantize": True,
+ },
+ ]
+ for plugin, test_config in product(test_plugins, test_configs):
+ # checkpoint loaded model
+ model_save = model_fn()
+ model_load = copy.deepcopy(model_save)
+
+ optimizer = AdamW(model.parameters(), lr=0.001)
+ criterion = loss_fn
+
+ booster = Booster(plugin=plugin)
+ model_save = booster.enable_lora(model_save, **test_config)
+ model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)
+
+ with shared_tempdir() as tempdir:
+ lora_ckpt_path = os.path.join(tempdir, "ckpt")
+ booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
+ dist.barrier()
+
+ # The Lora checkpoint should be small in size
+ checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
+ assert checkpoint_size_mb < 1
+
+ model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
+ model_load, _, _, _, _ = booster.boost(model_load)
+
+ check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
+
+ # test fwd bwd correctness
+ test_model = model_load
+ model_copy = copy.deepcopy(model_load)
+
+ data = data_gen_fn()
+ data = {
+ k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+ }
+
+ output = test_model(**data)
+ output = output_transform_fn(output)
+ loss = criterion(output)
+
+ booster.backward(loss, optimizer)
+ optimizer.clip_grad_by_norm(1.0)
+ optimizer.step()
+
+ for (n1, p1), (n2, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
+ if "lora_" in n1:
+ # lora modules require gradients, thus updated
+ assert p1.requires_grad
+ assert not torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
+ else:
+ if not p1.requires_grad:
+ torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
+
+
+def run_lora_test():
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ task_type = None
+ if name == "transformers_llama_for_casual_lm":
+ task_type = "CAUSAL_LM"
+ if name == "transformers_llama_for_sequence_classification":
+ task_type = "SEQ_CLS"
+ check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_lora_test()
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_lora():
+ spawn(run_dist, 2)
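The forward/backward check in the new test relies on the invariant that, once LoRA is enabled, only the injected adapter weights remain trainable while the base weights stay frozen. A small self-contained sketch of that invariant using `peft` directly; the toy `nn.Sequential` model and its target module names are placeholders, not part of the test:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy base model; the real test uses the llama variants from the model zoo.
base = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
# "0" and "1" are the child-module names of the nn.Sequential above.
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["0", "1"]))

trainable = [name for name, param in peft_model.named_parameters() if param.requires_grad]
# Only adapter weights should require gradients, hence be the only ones the optimizer updates.
assert all("lora_" in name for name in trainable), trainable
```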
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index a349bc5a9..a88f5f9cc 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -16,7 +16,6 @@ DIM = 16
def run_test(rank, world_size, port):
colossalai.launch(
- config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 62d61a3d4..30122d31a 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -20,7 +20,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
# Here we do not need TF32, since it brings absolute error on results
torch.backends.cuda.matmul.allow_tf32 = False
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
local_rank = dist.get_rank()
MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 74feeeb59..660fbd358 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -128,7 +128,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
assert batch_size % world_size == 0
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index 2f08a335d..b7be54d26 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -60,7 +60,6 @@ def run_moe_init(expert_parallel):
def _run_test(rank, world_size, port, expert_parallel):
colossalai.launch(
- config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py
index 7ada4090f..7932fa8a7 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_moe/test_moe_hybrid_zero.py
@@ -81,7 +81,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_optim_test(rank, world_size, stage=1)
run_zero_optim_test(rank, world_size, stage=2)
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
index 717bb99fb..fae189bac 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -164,7 +164,6 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
def run_dist(rank, world_size, port):
colossalai.launch(
- config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
index 1bff21066..3bb08b49e 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -61,7 +61,7 @@ def run_zero_test(local_rank, stage=1):
def run_dist(rank, world_size, port, stage):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 4f6067aaa..224c5c3b9 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -66,7 +66,7 @@ def run_zero_test(local_rank, stage=1):
def run_dist(rank, world_size, port, stage):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py
index 6d932156a..002649905 100644
--- a/tests/test_optimizer/test_adam_kernel.py
+++ b/tests/test_optimizer/test_adam_kernel.py
@@ -69,7 +69,7 @@ class FusedAdamKernel(AdamKernel):
fused_optim = FusedOptimizerLoader().load()
self.fused_adam = fused_optim.multi_tensor_adam
- self.dummy_overflow_buf = torch.cuda.IntTensor([0])
+ self.dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
multi_tensor_applier(
diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py
index 6f5e734b7..48a8d12e0 100644
--- a/tests/test_pipeline/test_p2p_communication.py
+++ b/tests/test_pipeline/test_p2p_communication.py
@@ -71,7 +71,7 @@ def check_p2p_communication():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_p2p_communication()
diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
index 1b7b0073f..e2f71ff89 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
+ self.num_model_chunks = 1
@property
def num_stages(self):
diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
index 9f8c1ad32..d39c5ea91 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
+ self.num_model_chunks = 1
@property
def num_stages(self):
diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py
index f8820688e..a626b834a 100644
--- a/tests/test_pipeline/test_schedule/test_interleaved.py
+++ b/tests/test_pipeline/test_schedule/test_interleaved.py
@@ -58,7 +58,7 @@ def run_pp(
This test examines the correctness of interleaved 1F1B, compared with torch.
Be aware that it contains some hardcoded values.
"""
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
# create model
seed_all(1453)
diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
index 590800780..c4bfa7b69 100644
--- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py
+++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
@@ -148,7 +148,7 @@ def run_dist(
num_microbatch: int,
batch_size: int,
):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
examine_pp(num_microbatch, batch_size)
diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py
index ed8284b3e..5146a86c8 100644
--- a/tests/test_pipeline/test_stage_manager.py
+++ b/tests/test_pipeline/test_stage_manager.py
@@ -64,7 +64,7 @@ def check_stage_manager():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_stage_manager()
diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py
index f9eab132f..9aa24a166 100644
--- a/tests/test_shardformer/test_flash_attention.py
+++ b/tests/test_shardformer/test_flash_attention.py
@@ -4,11 +4,7 @@ from copy import copy
import torch
from torch.testing import assert_close
-from colossalai.kernel.kernel_loader import (
- FlashAttentionLoader,
- FlashAttentionWithCustomMaskLoader,
- FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer.attn import invert_mask
from colossalai.testing import clear_cache_before_run, parameterize
@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
if ext.is_available():
ext.assert_compatible()
avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
- for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
- ext = ext_cls()
- if ext.is_available():
- ext.assert_compatible()
- avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
test_sets = {
"none": (lambda dtype: ({}, None), avail_attn_funcs),
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
index f652d18e9..b2c81f8ab 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
@@ -193,13 +193,13 @@ def run_3d_test(test_config):
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
index a749a2966..ee1fd9333 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
@@ -151,13 +151,13 @@ def run_3d_test(test_config):
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
index 41f06a4c3..be257e818 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
@@ -183,13 +183,13 @@ def run_3d_test(test_config):
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py
index 414157c22..8ace0e028 100644
--- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py
+++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py
@@ -14,7 +14,7 @@ CONFIG = dict(
def check_dist_crossentropy(rank, world_size, port, ignore_index):
disable_existing_loggers()
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py
index 576620e6c..f1e646ed2 100644
--- a/tests/test_shardformer/test_layer/test_dropout.py
+++ b/tests/test_shardformer/test_layer/test_dropout.py
@@ -56,7 +56,7 @@ def check_dropout_replicated_input():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_dropout_parallel_input()
check_dropout_replicated_input()
diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py
index 3dbbcd766..3d7dc2088 100644
--- a/tests/test_shardformer/test_layer/test_embedding.py
+++ b/tests/test_shardformer/test_layer/test_embedding.py
@@ -43,7 +43,7 @@ def check_embedding_1d(lazy_init: bool):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_embedding_1d()
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index e9aa0dbed..5aa8584a0 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -143,7 +143,7 @@ def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, ove
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv
check_gpt2_qkv_fused_linear_1d()
diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py
index 3eb3bb2e5..b0deff6b8 100644
--- a/tests/test_shardformer/test_layer/test_layernorm.py
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -41,7 +41,7 @@ def check_layernorm(lazy_init: bool):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_layernorm()
diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index 21d3190de..541aa3251 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -185,7 +185,7 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
def check_dist_linear(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_dist_linear_test()
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index 5e996d2ba..dc14fd591 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -126,7 +126,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# test for linear conv
check_linear_conv_1d_col()
diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py
index 13b1a13e7..a6cf61f8f 100644
--- a/tests/test_shardformer/test_layer/test_sequence_parallel.py
+++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py
@@ -165,7 +165,7 @@ def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
def check_all2all_attn(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_seq_parallel_attn()
diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
index b23a44f2d..fdd304256 100644
--- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
+++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
@@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool):
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
- assert dist_embedding_1d.num_embeddings == 64
+ assert dist_embedding_1d.num_embeddings == 128
assert dist_embedding_1d.embedding_dim == 32
assert embedding_copy.weight is dist_embedding_1d.weight
@@ -45,7 +45,7 @@ def check_vocab_embedding_1d(lazy_init: bool):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_vocab_embedding_1d()
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index d5fc2c30f..1835a5c8e 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -14,12 +14,14 @@ from torch.testing import assert_close
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
+from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
def build_model(
@@ -225,7 +227,7 @@ def check_output_hidden_state(
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
- assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
+ assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
def check_weight(
@@ -247,11 +249,10 @@ def check_weight(
continue
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
- sharded_weight_list = [
- torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
- ]
- dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
- sharded_weight = torch.cat(sharded_weight_list, dim=dim)
+ sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)
+
+ if is_padded_tensor(sharded_weight):
+ sharded_weight = to_unpadded_tensor(sharded_weight)
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 919557797..3ec394768 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -231,13 +231,13 @@ def run_bert_3d_test(test_config):
def check_bert(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_bert_test()
def check_bert_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_bert_3d_test()
diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py
index 2c56b0435..712c5c1e1 100644
--- a/tests/test_shardformer/test_model/test_shard_blip2.py
+++ b/tests/test_shardformer/test_model/test_shard_blip2.py
@@ -99,7 +99,6 @@ def run_blip2_test(
def check_blip2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index cc0786618..6ab0369e0 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -209,13 +209,13 @@ def run_bloom_3d_test(test_config):
def check_bloom(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_bloom_test()
def check_bloom_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_bloom_3d_test()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index 405ceba32..6ce020b68 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
+ check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
@@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
# TODO: ChatGLMModel output is [S, B, H]; merging the batch dimension across pipeline stages is wrong
- # if org_model.__class__.__name__ == "ChatGLMModel":
- # check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+ if org_model.__class__.__name__ == "ChatGLMModel":
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
- "enable_all_optimization": True,
+ "enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
- "enable_all_optimization": True,
+ "enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
@@ -258,7 +259,6 @@ def run_chatglm_3d_test(test_config):
def check_chatglm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
@@ -271,7 +271,6 @@ def check_chatglm(rank, world_size, port):
def check_chatglm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py
index 5e2efcd80..8074f9d61 100644
--- a/tests/test_shardformer/test_model/test_shard_falcon.py
+++ b/tests/test_shardformer/test_model/test_shard_falcon.py
@@ -176,13 +176,13 @@ def run_falcon_3d_test(test_config):
def check_falcon(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_test()
def check_falcon_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_3d_test()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 4aac7f3d4..72ea2b089 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -275,7 +275,6 @@ def run_gpt2_3d_test(test_config):
def check_gpt2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
@@ -288,7 +287,6 @@ def check_gpt2(rank, world_size, port):
def check_gpt2_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 27f904292..104ede981 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -32,7 +32,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
model_fn, loss_fn, test_config
)
if enable_gradient_checkpointing:
- org_model.gradient_checkpointing_enable()
+ # org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable()
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
@@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False,
"precision": "fp32",
"enable_gradient_checkpointing": True,
- "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
- num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
- ),
+ "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
},
{
"tp_size": 4,
@@ -303,9 +301,6 @@ def run_llama_test(test_config):
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
- num_stages=2,
- num_model_chunks=2,
- num_model_layers=8,
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
},
@@ -324,13 +319,13 @@ def run_llama_3d_test(test_config):
def check_llama(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_test()
def check_llama_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_llama_3d_test()
diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py
index 07bc91b33..deced9d56 100644
--- a/tests/test_shardformer/test_model/test_shard_mistral.py
+++ b/tests/test_shardformer/test_model/test_shard_mistral.py
@@ -91,7 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
- atol, rtol = 1e-4, 1e-3
+ atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
@@ -114,6 +114,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "enable_all_optimization": True,
+ "use_lazy_init": False,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
{
"tp_size": 4,
"pp_size": 1,
@@ -152,11 +170,10 @@ def run_mistral_test(test_config):
def check_mistral(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mistral_test()
-@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 523ed879b..b7c77d20b 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -233,7 +233,6 @@ def run_opt_3d_test(test_config):
def check_OPTModel(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
@@ -246,7 +245,6 @@ def check_OPTModel(rank, world_size, port):
def check_opt_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py
index a8d4cb635..e872d7f7b 100644
--- a/tests/test_shardformer/test_model/test_shard_sam.py
+++ b/tests/test_shardformer/test_model/test_shard_sam.py
@@ -57,7 +57,7 @@ def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_f
def check_sam(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_sam_test()
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 9b22d54d7..521dc9130 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check weights
if test_config["precision"] == "fp32":
- atol, rtol = 5e-4, 1e-3
+ # TODO: the precision loss in weight checking is too significant.
+ atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
@@ -221,7 +222,6 @@ def run_t5_3d_test(test_config):
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
@@ -234,7 +234,6 @@ def check_t5(rank, world_size, port):
def check_t5_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(
- config={},
rank=rank,
world_size=world_size,
host="localhost",
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 3a8af2d6d..d33b52b42 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -168,13 +168,13 @@ def run_vit_3d_test(test_config):
def check_vit(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_vit_test()
def check_vit_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_vit_3d_test()
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index 6efb8a922..beb2a6761 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 2,
"enable_metadata_cache": False,
"enable_all_optimization": True,
- "use_lazy_init": True,
+ "use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
@@ -196,13 +196,13 @@ def run_whisper_3d_test(test_config):
def check_whisper(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_whisper_test()
def check_whisper_3d(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_whisper_3d_test()
diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py
index 4b741c21b..4735df717 100644
--- a/tests/test_shardformer/test_with_torch_ddp.py
+++ b/tests/test_shardformer/test_with_torch_ddp.py
@@ -71,7 +71,7 @@ def check_shardformer_with_ddp(lazy_init: bool):
def run_dist(rank, world_size, port):
disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
check_shardformer_with_ddp()
diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py
index 5e969b1aa..a2414d949 100644
--- a/tests/test_tensor/test_comm_spec_apply.py
+++ b/tests/test_tensor/test_comm_spec_apply.py
@@ -178,7 +178,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
def check_comm(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
assert rank == dist.get_rank()
diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py
index 6d1640b4f..fd9996710 100644
--- a/tests/test_tensor/test_dtensor/test_comm_spec.py
+++ b/tests/test_tensor/test_dtensor/test_comm_spec.py
@@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank):
def check_comm(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
assert rank == dist.get_rank()
diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py
index 33ae59d01..60efa315e 100644
--- a/tests/test_tensor/test_dtensor/test_dtensor.py
+++ b/tests/test_tensor/test_dtensor/test_dtensor.py
@@ -21,7 +21,7 @@ class TestModel(torch.nn.Module):
def check_dtensor(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
test_model = TestModel(8, 8).to("cuda")
original_tensor = torch.rand(4, 8).to("cuda")
compare_output = test_model(original_tensor)
diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py
index 3bface1d2..6e426d0e8 100644
--- a/tests/test_tensor/test_dtensor/test_layout_converter.py
+++ b/tests/test_tensor/test_dtensor/test_layout_converter.py
@@ -20,7 +20,7 @@ mesh_shape = (2, 2)
def check_one_step_transform(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# [[0, 1],
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -82,7 +82,7 @@ def check_one_step_transform(rank, world_size, port):
def check_layout_converting(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -141,7 +141,7 @@ def check_layout_converting(rank, world_size, port):
def check_layout_converting_apply(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py
index 7d6f8979d..6dbbe5de6 100644
--- a/tests/test_tensor/test_mix_gather.py
+++ b/tests/test_tensor/test_mix_gather.py
@@ -296,7 +296,7 @@ def check_two_all_gather_RS01(device_mesh, rank):
def check_comm(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 8)
assert rank == dist.get_rank()
diff --git a/tests/test_tensor/test_padded_tensor.py b/tests/test_tensor/test_padded_tensor.py
new file mode 100644
index 000000000..6d19845df
--- /dev/null
+++ b/tests/test_tensor/test_padded_tensor.py
@@ -0,0 +1,46 @@
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def check_padded_tensor(rank, world_size, port):
+ disable_existing_loggers()
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ original_tensor = torch.rand(32, 64).to("cuda")
+
+ device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+ target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
+ d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
+
+ padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
+ assert padded_tensor.dist_layout == d_tensor.dist_layout
+
+ tensor_copy = padded_tensor.clone()
+ assert is_padded_tensor(tensor_copy)
+ assert is_distributed_tensor(tensor_copy)
+
+ tensor_detached = padded_tensor.detach()
+ assert is_padded_tensor(tensor_detached)
+ assert is_distributed_tensor(tensor_detached)
+
+ unpadded_tensor = to_unpadded_tensor(padded_tensor)
+ assert unpadded_tensor.shape == d_tensor.shape
+ assert is_distributed_tensor(unpadded_tensor)
+
+ global_tensor = to_global(unpadded_tensor)
+ assert global_tensor.shape == original_tensor.shape
+
+
+@rerun_if_address_is_in_use()
+def test_padded_tensor():
+ world_size = 4
+ spawn(check_padded_tensor, world_size)
+
+
+if __name__ == "__main__":
+ test_padded_tensor()
diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py
index b2bc84edd..8d8d8ef51 100644
--- a/tests/test_tensor/test_shape_consistency_apply.py
+++ b/tests/test_tensor/test_shape_consistency_apply.py
@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_apply(rank, world_size, port):
disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
index 879eeccde..412a95f6a 100644
--- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py
+++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py
@@ -49,7 +49,7 @@ def exam_chunk_memory(keep_gathered, pin_memory):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_chunk_memory()
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index e4dc569b8..257311328 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -108,7 +108,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_chunk_basic()
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 3a9742e01..d9084fd5a 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -100,8 +100,7 @@ def exam_gpt_fwd_bwd(
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_gpt_fwd_bwd()
diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py
index 90ad62d1a..1e49f2851 100644
--- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py
+++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py
@@ -80,8 +80,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_gemini_use_rmt()
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index 36a803492..fd0e9fd7c 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -138,8 +138,7 @@ def exam_gemini_grad_acc(
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_gemini_grad_acc()
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 23b3504fd..0a9bac092 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -117,8 +117,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_grad_clipping()
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index 7f3c7176e..e54804fc5 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -107,8 +107,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_inference()
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 71bb27b4a..a9366e7bc 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -183,8 +183,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_model_step()
exam_tiny_example()
diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py
index cf3658bf9..9c8c497f3 100644
--- a/tests/test_zero/test_gemini/test_search.py
+++ b/tests/test_zero/test_gemini/test_search.py
@@ -47,7 +47,7 @@ def exam_chunk_manager():
def run_dist(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_search_chunk_size()
exam_chunk_manager()
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
index cbf5169fc..23e2d8083 100644
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -76,8 +76,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
index 87cb1cdfe..8d70ae3b1 100644
--- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -68,8 +68,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
def run_dist(rank, world_size, port):
- config = {}
- colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_zero_optim_state_dict()
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index 11f738615..ed12bb72d 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -130,7 +130,7 @@ def exam_zero_1_grad_acc(sync):
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_grad_acc(sync=True)
exam_zero_1_grad_acc(sync=False)
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index e2196cfbf..06a29bd1d 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -178,7 +178,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_torch_ddp(world_size=world_size)
exam_zero_1_2()
diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py
index e9fc8598a..8543dfba0 100644
--- a/tests/test_zero/test_low_level/test_zero_ckpt.py
+++ b/tests/test_zero/test_low_level/test_zero_ckpt.py
@@ -103,7 +103,7 @@ def exam_zero_1_torch_ddp_ckpt():
def run_dist(rank, world_size, port):
- colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_torch_ddp_ckpt()
diff --git a/version.txt b/version.txt
index 449d7e73a..0f8268533 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.6
+0.3.7