From 29695cf70c2652e4017bd76ff6337572f5b05035 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 4 Mar 2024 16:18:13 +0800 Subject: [PATCH 01/43] [example]add gpt2 benchmark example script. (#5295) * benchmark gpt2 * fix fix fix fix * [doc] fix typo in Colossal-LLaMA-2/README.md (#5247) * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed ddp test (#5254) * [ci] fixed ddp test * polish * fix typo in applications/ColossalEval/README.md (#5250) * [ci] fix shardformer tests. (#5255) * fix ci fix * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests --------- Co-authored-by: Wenhao Chen * [doc] fix doc typo (#5256) * [doc] fix annotation display * [doc] fix llama2 doc * [hotfix]: add pp sanity check and fix mbs arg (#5268) * fix: fix misleading mbs arg * feat: add pp sanity check * fix: fix 1f1b sanity check * [workflow] fixed incomplete bash command (#5272) * [workflow] fixed oom tests (#5275) * [workflow] fixed oom tests * polish * polish * polish * [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276) * fix ci fix * fix test * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests * fix --------- Co-authored-by: Wenhao Chen * [shardformer] hybridparallelplugin support gradients accumulation. (#5246) * support gradients acc fix fix fix fix fix fix fix fix fix fix fix fix fix * fix fix * fix fix fix * [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230) * fix auto loading gpt2 tokenizer (#5279) * [doc] add llama2-13B disyplay (#5285) * Update README.md * fix 13b typo --------- Co-authored-by: binmakeswell * fix llama pretrain (#5287) * fix * fix * fix fix * fix fix fix * fix fix * benchmark gpt2 * fix fix fix fix * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * fix fix * fix fix fix * fix * fix fix fix fix fix * fix * Update shardformer.py --------- Co-authored-by: digger yu Co-authored-by: Frank Lee Co-authored-by: Wenhao Chen Co-authored-by: binmakeswell Co-authored-by: Zhongkai Zhao Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com> Co-authored-by: Desperado-Jia <502205863@qq.com> --- .github/workflows/build_on_pr.yml | 2 +- .github/workflows/build_on_schedule.yml | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 3 + colossalai/shardformer/layer/_operation.py | 27 +-- colossalai/shardformer/modeling/gpt2.py | 93 ++++++- colossalai/shardformer/policies/gpt2.py | 82 +++++-- colossalai/shardformer/shard/shardformer.py | 4 + examples/__init__.py | 0 examples/language/__init__.py | 0 examples/language/{llama2 => }/data_utils.py | 2 +- .../gpt/hybridparallelism/benchmark.py | 228 ++++++++++++++++++ examples/language/llama2/benchmark.py | 3 + examples/language/{llama2 => }/model_utils.py | 0 .../{llama2 => }/performance_evaluator.py | 0 tests/kit/model_zoo/registry.py | 2 +- .../test_plugin/test_3d_plugin.py | 2 +- .../test_plugin/test_gemini_plugin.py | 2 +- .../test_gemini_checkpoint_io.py | 2 +- tests/test_lazy/test_models.py | 2 +- .../test_gpt2_qkv_fused_linear_1d.py | 5 +- .../test_layer/test_linear_1d.py | 3 + .../test_layer/test_qkv_fused_linear_1d.py | 5 +- 22 files changed, 421 
insertions(+), 48 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/language/__init__.py rename examples/language/{llama2 => }/data_utils.py (99%) create mode 100644 examples/language/gpt/hybridparallelism/benchmark.py rename examples/language/{llama2 => }/model_utils.py (100%) rename examples/language/{llama2 => }/performance_evaluator.py (100%) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 2cad504f3..b01d15490 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -201,4 +201,4 @@ jobs: uses: actions/upload-artifact@v3 with: name: report - path: report/ + path: report/ \ No newline at end of file diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index ae1a5275e..510665b46 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -83,4 +83,4 @@ jobs: SERVER_URL: ${{github.server_url }} REPO: ${{ github.repository }} RUN_ID: ${{ github.run_id }} - WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} \ No newline at end of file diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index da67e6b41..bf677e052 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -36,6 +36,8 @@ from .pp_plugin_base import PipelinePluginBase DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): @@ -1059,6 +1061,7 @@ class HybridParallelPlugin(PipelinePluginBase): overlap_communication=overlap_communication, cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], ) self.max_norm = max_norm diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 4bca335c8..d4960c7e4 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -9,6 +9,7 @@ except: try: import fused_weight_gradient_mlp_cuda + _grad_accum_fusion_available = True except ImportError: _grad_accum_fusion_available = False @@ -78,7 +79,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
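# Note on the change just below: `bias` may legitimately be None when the wrapped layer was built
# without a bias term (the backward pass already guards `grad_bias` with `use_bias`), so the
# Gemini-friendly `.view()` trick is now applied to `bias` only when a bias tensor actually exists.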
weight = weight.view(weight.shape) - bias = bias.view(bias.shape) + if bias is not None: + bias = bias.view(bias.shape) total_input = input grad_input = grad_output.matmul(weight.T) @@ -91,9 +93,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -115,7 +116,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function): ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce - if bias is not None: output = F.linear(input_, weight, bias) else: @@ -143,9 +143,8 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if ctx.async_grad_allreduce: # Asynchronous all-reduce handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # all-reduce scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad @@ -228,9 +227,8 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # reduce-scatter scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad @@ -394,9 +392,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_.shape, dtype=input_parallel.dtype, device=input_parallel.device ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) - # Delay the start of weight gradient computation shortly (3us) to have - # reduce-scatter scheduled first and have GPU resources allocated - _ = torch.empty(1, device=grad_output.device) + 1 + # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have + # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py grad_weight = total_input.t().matmul(grad_output) grad_bias = grad_output.sum(dim=0) if use_bias else None @@ -431,7 +428,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): input_parallel = torch.cat(tensor_list, dim=dim).contiguous() # calculate gradient if len(input_parallel.shape) > 2: - input_parallel = input_parallel.view(-1, input_parallel.shape[-1]) + input_parallel = 
input_parallel.view(-1, input_parallel.shape[-1]) grad_weight = input_parallel.t().matmul(grad_output) # wait until reduce-scatter finished reducescatter_handle.wait() diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 055e3096d..3e5cc6015 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -24,6 +24,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + class GPT2PipelineForwards: """ @@ -326,7 +328,15 @@ class GPT2PipelineForwards: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + if shard_config.enable_tensor_parallelism: + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + loss = loss_fct(shift_logits, shift_labels) + if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1006,3 +1016,84 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ) return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import GPT2LMHeadModel + + def forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + if shard_config.enable_tensor_parallelism: + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 022e6ff5b..303766993 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,7 +5,12 @@ from torch import Tensor, nn import colossalai.shardformer.layer as col_nn -from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn +from ..modeling.gpt2 import ( + GPT2PipelineForwards, + get_gpt2_flash_attention_forward, + get_lm_forward_with_dist_cross_entropy, + gpt2_sequence_parallel_forward_fn, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -87,9 +92,7 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -167,15 +170,35 @@ class GPT2Policy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.wpe) - held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if 
stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.ln_f) + else: + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -189,13 +212,27 @@ class GPT2Policy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - } + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + else: + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -232,9 +269,10 @@ class GPT2LMHeadModelPolicy(GPT2Policy): GPT2LMHeadModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False} ) - ] + ], + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } module_policy.update(addon_module) @@ -249,7 +287,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): + if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 7a0d75bf2..b132f47fd 100644 --- 
a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,3 +1,4 @@ +import os from typing import Dict, List, Tuple import torch.nn as nn @@ -9,6 +10,9 @@ from ..policies.base_policy import Policy from .shard_config import ShardConfig from .sharder import ModelSharder +# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + class ShardFormer: """ diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/language/__init__.py b/examples/language/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/language/llama2/data_utils.py b/examples/language/data_utils.py similarity index 99% rename from examples/language/llama2/data_utils.py rename to examples/language/data_utils.py index 6b9e8ef28..ec849ef9d 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/data_utils.py @@ -121,4 +121,4 @@ class RandomDataset(Dataset): "input_ids": self.input_ids[idx], "attention_mask": self.attention_mask[idx], "labels": self.input_ids[idx], - } + } \ No newline at end of file diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py new file mode 100644 index 000000000..1315deae6 --- /dev/null +++ b/examples/language/gpt/hybridparallelism/benchmark.py @@ -0,0 +1,228 @@ +import argparse +import resource +from contextlib import nullcontext + +import torch +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision +from torch.optim import Adam +from tqdm import tqdm +from transformers.models.gpt2.configuration_gpt2 import GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + +import colossalai + +# import colossalai.utils.device as device_utils +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.utils import get_current_device +from examples.language.data_utils import RandomDataset +from examples.language.model_utils import format_numel_str, get_model_numel +from examples.language.performance_evaluator import PerformanceEvaluator + +# ============================== +# Constants +# ============================== +MODEL_CONFIGS = { + "118M": GPT2Config(activation_function="gelu"), + "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"), + "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"), + "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"), +} + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", type=str, default="6.21B", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=200, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=3, help="Number 
of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1) + parser.add_argument("--zero", type=int, default=0) + parser.add_argument("--pp_style", type=str, default="1f1b") + parser.add_argument("--num_model_chunks", type=int, default=2) + parser.add_argument("--cpu_offload", action="store_true", help="Use gradient checkpointing") + args = parser.parse_args() + + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + def empty_init(): + pass + + # ============================== + # Initialize Booster + # ============================== + use_empty_init = True + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision="bf16", + shard_param_frac=args.shard_param_frac, + offload_optim_frac=args.offload_optim_frac, + offload_param_frac=args.offload_param_frac, + tp_size=args.tp, + extra_dp_size=args.extra_dp, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + placement_policy="auto", + precision="bf16", + warmup_non_model_data_ratio=args.warmup_ratio, + tp_size=args.tp, + extra_dp_size=args.extra_dp, + ) + elif args.plugin == "fsdp": + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ) + ) + elif args.plugin == "fsdp_cpu": + if use_empty_init: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + param_init_fn=empty_init(), + ) + else: + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + pp_style=args.pp_style, + zero_stage=args.zero, + num_model_chunks=args.num_model_chunks, + enable_all_optimization=True, + num_microbatches=args.mbs, + cpu_offload=args.cpu_offload, + precision="bf16", + ) + elif args.plugin == "3d_cpu": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + cpu_offload=True, + enable_fused_normalization=torch.cuda.is_available(), + 
num_microbatches=args.mbs, + initial_scale=2**8, + precision="bf16", + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + + config = MODEL_CONFIGS[args.config] + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + + with init_ctx: + model = GPT2LMHeadModel(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + model.config.n_layer, + model.config.n_embd, + model.config.vocab_size, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = Adam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + booster.execute_pipeline( + data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False + ) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 54b023f64..832465490 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -19,6 +19,9 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from examples.language.data_utils import RandomDataset +from examples.language.model_utils import format_numel_str, 
get_model_numel +from examples.language.performance_evaluator import PerformanceEvaluator # ============================== # Constants diff --git a/examples/language/llama2/model_utils.py b/examples/language/model_utils.py similarity index 100% rename from examples/language/llama2/model_utils.py rename to examples/language/model_utils.py diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/performance_evaluator.py similarity index 100% rename from examples/language/llama2/performance_evaluator.py rename to examples/language/performance_evaluator.py diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index a16b16ad6..fce81ab52 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -102,4 +102,4 @@ class ModelZooRegistry(dict): return new_dict -model_zoo = ModelZooRegistry() +model_zoo = ModelZooRegistry() \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index d629e769d..285c4866c 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -276,4 +276,4 @@ def test_gemini_plugin(early_stop: bool = True): if __name__ == "__main__": - test_gemini_plugin(early_stop=False) + test_gemini_plugin(early_stop=False) \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 17dfa3a18..0f72d2bcd 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -185,4 +185,4 @@ def test_gemini_plugin_3d(early_stop: bool = True): if __name__ == "__main__": - test_gemini_plugin(early_stop=False) + test_gemini_plugin(early_stop=False) \ No newline at end of file diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 61cac1d83..daddf6dc7 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -186,4 +186,4 @@ def test_gemini_ckpIO_3d(): if __name__ == "__main__": - test_gemini_ckpIO() + test_gemini_ckpIO() \ No newline at end of file diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index d0c4cd0a7..aeca5f21d 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -24,4 +24,4 @@ def test_torchvision_models_lazy_init(subset, default_device): if __name__ == "__main__": - test_torchvision_models_lazy_init("transformers", "cpu") + test_torchvision_models_lazy_init("transformers", "cpu") \ No newline at end of file diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 10ffdcd71..e056860ed 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - # This code is copied from https://github.com/huggingface/transformers +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + class Conv1D(nn.Module): 
""" 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 5bacf1865..defa4afb9 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,6 +12,8 @@ from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index b02d58181..5e996d2ba 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - # This code is copied from https://github.com/huggingface/transformers +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + class Conv1D(nn.Module): """ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). From 822241a99cca799e1fca250ff2fb7f54ea0f8dcd Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Tue, 5 Mar 2024 12:08:58 +0800 Subject: [PATCH 02/43] [doc] sora release (#5425) * [doc] sora release * [doc] sora release * [doc] sora release * [doc] sora release --- README.md | 20 ++++++++++++++++---- applications/README.md | 1 + docs/README-zh-Hans.md | 20 ++++++++++++++++---- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 442e6bbcd..f045c5604 100644 --- a/README.md +++ b/README.md @@ -25,16 +25,13 @@ ## Latest News +* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) * [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient) * [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) * [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 
22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) -* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) -* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) -* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) -* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) ## Table of Contents
    @@ -43,6 +40,7 @@
  • Colossal-AI for Real World Applications
      +
    • Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million
    • Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution
    • ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
    • AIGC: Acceleration of Stable Diffusion
    • @@ -126,6 +124,20 @@ distributed training and inference in a few lines.

      (back to top)

      ## Colossal-AI in the Real World +### Open-Sora + +[Open-Sora](https://github.com/hpcaitech/Open-Sora):Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million +[[code]](https://github.com/hpcaitech/Open-Sora) +[[blog]](https://hpc-ai.com/blog/open-sora) + +

      + ### Colossal-LLaMA-2 diff --git a/applications/README.md b/applications/README.md index 49a2900f1..8abe1e52d 100644 --- a/applications/README.md +++ b/applications/README.md @@ -4,6 +4,7 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: +- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million - [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. - [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index c25f19795..90ad5540a 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,16 +24,13 @@ ## 新闻 +* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) * [2023/11] [Enhanced MoE Parallelism, Open-source MoE Model Training Can Be 9 Times More Efficient](https://www.hpc-ai.tech/blog/enhanced-moe-parallelism-open-source-moe-model-training-can-be-9-times-more-efficient) * [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution) * [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training) * [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth) -* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining) -* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) -* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana) -* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs) ## 目录
        @@ -42,6 +39,7 @@
      • Colossal-AI 成功案例
          +
        • Open-Sora:开源Sora复现方案,成本降低46%,序列扩充至近百万
        • Colossal-LLaMA-2: 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
        • ColossalChat:完整RLHF流程0门槛克隆ChatGPT
        • AIGC: 加速 Stable Diffusion
        • @@ -121,6 +119,20 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

          (返回顶端)

          ## Colossal-AI 成功案例 +### Open-Sora + +[Open-Sora](https://github.com/hpcaitech/Open-Sora):开源Sora复现方案,成本降低46%,序列扩充至近百万 +[[代码]](https://github.com/hpcaitech/Open-Sora) +[[博客]](https://hpc-ai.com/blog/open-sora) + +

          + ### Colossal-LLaMA-2 - 7B:千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2 From 070df689e627d07f28c8087ec85a4299c73145d9 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Mar 2024 15:35:54 +0800 Subject: [PATCH 03/43] [devops] fix extention building (#5427) --- .cuda_ext.json | 12 ++--- .../compatiblity_test_on_dispatch.yml | 2 +- .github/workflows/compatiblity_test_on_pr.yml | 2 +- .../compatiblity_test_on_schedule.yml | 2 +- .../workflows/cuda_ext_check_before_merge.yml | 2 +- .github/workflows/doc_test_on_pr.yml | 2 +- .github/workflows/doc_test_on_schedule.yml | 2 +- .../workflows/example_check_on_dispatch.yml | 2 +- .github/workflows/example_check_on_pr.yml | 2 +- .../workflows/example_check_on_schedule.yml | 2 +- colossalai/cli/check/check_installation.py | 2 +- .../inference/serving/ray_serve/README.md | 2 +- .../inference/serving/torch_serve/README.md | 2 +- .../serving/torch_serve/docker/Dockerfile | 2 +- colossalai/nn/optimizer/cpu_adam.py | 2 +- docker/Dockerfile | 2 +- docs/README-zh-Hans.md | 44 +++++++++---------- docs/source/en/get_started/installation.md | 2 +- examples/images/diffusion/README.md | 2 +- examples/images/diffusion/test_ci.sh | 2 +- examples/language/llama2/README.md | 2 +- examples/language/openmoe/README.md | 2 +- extensions/utils.py | 2 +- setup.py | 2 +- 24 files changed, 50 insertions(+), 50 deletions(-) diff --git a/.cuda_ext.json b/.cuda_ext.json index eba19cf05..b8269f837 100644 --- a/.cuda_ext.json +++ b/.cuda_ext.json @@ -1,16 +1,16 @@ { "build": [ { - "torch_command": "pip install torch==1.12.1+cu102 torchvision==0.13.1+cu102 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu102", - "cuda_image": "hpcaitech/cuda-conda:10.2" + "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121", + "cuda_image": "hpcaitech/cuda-conda:12.1" }, { - "torch_command": "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113", - "cuda_image": "hpcaitech/cuda-conda:11.3" + "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118", + "cuda_image": "hpcaitech/cuda-conda:11.8" }, { - "torch_command": "pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116", - "cuda_image": "hpcaitech/cuda-conda:11.6" + "torch_command": "pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1", + "cuda_image": "hpcaitech/cuda-conda:11.7" } ] } diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 508321299..a6f9582ac 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -83,7 +83,7 @@ jobs: fi - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . pip install -r requirements/requirements-test.txt - name: Unit Testing run: | diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index cc17c66f9..ede6c380a 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -78,7 +78,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . 
pip install -r requirements/requirements-test.txt - name: Unit Testing run: | diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 158fe751b..1cf456ff6 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -75,7 +75,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . pip install -r requirements/requirements-test.txt - name: Unit Testing diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml index 686f0f395..14f53bd69 100644 --- a/.github/workflows/cuda_ext_check_before_merge.yml +++ b/.github/workflows/cuda_ext_check_before_merge.yml @@ -51,4 +51,4 @@ jobs: - name: Build run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index 51238905e..8afc46b87 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -89,7 +89,7 @@ jobs: - name: Install ColossalAI run: | source activate pytorch - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . - name: Test the Doc run: | diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml index b3536184d..e2491e460 100644 --- a/.github/workflows/doc_test_on_schedule.yml +++ b/.github/workflows/doc_test_on_schedule.yml @@ -32,7 +32,7 @@ jobs: - name: Install ColossalAI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . - name: Install Doc Test Requirements run: | diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index bba321fd2..24e726b4f 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -53,7 +53,7 @@ jobs: uses: actions/checkout@v3 - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . - name: Test the example run: | dir=${{ matrix.directory }} diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index fcff8e569..728f059c1 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -88,7 +88,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . - name: Test the example run: | diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index abb947949..efb131a86 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -42,7 +42,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v . - name: Traverse all files run: | diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index 772c513ff..f5602bbe6 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -76,7 +76,7 @@ def check_installation(): click.echo("") click.echo(f"Note:") click.echo( - f"1. AOT (ahead-of-time) compilation of the CUDA kernels occurs during installation when the environment variable CUDA_EXT=1 is set" + f"1. AOT (ahead-of-time) compilation of the CUDA kernels occurs during installation when the environment variable BUILD_EXT=1 is set" ) click.echo(f"2. 
If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") diff --git a/colossalai/legacy/inference/serving/ray_serve/README.md b/colossalai/legacy/inference/serving/ray_serve/README.md index 1d4082387..888f04bb5 100644 --- a/colossalai/legacy/inference/serving/ray_serve/README.md +++ b/colossalai/legacy/inference/serving/ray_serve/README.md @@ -25,7 +25,7 @@ conda install -c conda-forge cupy cudnn cutensor nccl cuda-version=11.6 # install colossalai with PyTorch extensions cd -CUDA_EXT=1 pip install -e . +BUILD_EXT=1 pip install -e . # install other dependencies pip install triton==2.0.0.dev20221202 diff --git a/colossalai/legacy/inference/serving/torch_serve/README.md b/colossalai/legacy/inference/serving/torch_serve/README.md index 6bd145bc3..fcf2e36d2 100644 --- a/colossalai/legacy/inference/serving/torch_serve/README.md +++ b/colossalai/legacy/inference/serving/torch_serve/README.md @@ -25,7 +25,7 @@ conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit cd pip install -r requirements/requirements.txt pip install -r requirements/requirements-test.txt -CUDA_EXT=1 pip install -e . +BUILD_EXT=1 pip install -e . # install torchserve cd diff --git a/colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile b/colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile index 6d780a847..755812397 100644 --- a/colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile +++ b/colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile @@ -38,7 +38,7 @@ ARG VERSION=main RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git && \ cd ./ColossalAI && \ git checkout 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 && \ - CUDA_EXT=1 pip install -v --no-cache-dir . + BUILD_EXT=1 pip install -v --no-cache-dir . # install titans RUN pip install --no-cache-dir titans diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 5be629fb2..68fb582e5 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -78,7 +78,7 @@ class CPUAdam(NVMeOptimizer): super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode cpu_adam = CPUAdamLoader().load() - # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification + # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) def torch_adam_update( diff --git a/docker/Dockerfile b/docker/Dockerfile index 26d3fab1b..0e796a9d4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -37,7 +37,7 @@ RUN git clone https://github.com/NVIDIA/apex && \ ARG VERSION=main RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git \ && cd ./ColossalAI \ - && CUDA_EXT=1 pip install -v --no-cache-dir . + && BUILD_EXT=1 pip install -v --no-cache-dir . 
# install titans RUN pip install --no-cache-dir titans diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 90ad5540a..bc4106d12 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -146,25 +146,25 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-13b-base) [[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary) -| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot)| AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | -| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :--------------: | :-------------: | :-------------: | -| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | -| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | -| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | -| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | -| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | -| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | -| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | -| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | -| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | -| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | -| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | -| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | -| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | -| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | -| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | -| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | -| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | +| Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot) | AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | +|:------------------------------:|:----------:|:---------------:|:-------------:|:--------------:|:----------------:|:---------------:|:--------------:| +| Baichuan-7B | - | 1.2T | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 | +| Baichuan-13B-Base | - | 1.4T | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 | +| Baichuan2-7B-Base | - | 2.6T | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 | +| Baichuan2-13B-Base | - | 2.6T | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 | +| ChatGLM-6B | - | 1.0T | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 | +| ChatGLM2-6B | - | 1.4T | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 | +| InternLM-7B | - | 1.6T | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 | +| Qwen-7B | - | 2.2T | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 | +| Llama-2-7B | - | 2.0T | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - | +| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | 37.43 | 29.92 | 32.00 | 27.57 | - | +| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | 38.56 | 31.52 | 30.99 | 25.95 | - | +| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 | +| TigerResearch/tigerbot-7b-base 
| Llama-2-7B | 0.3T | 43.73 | 42.04 | 37.64 | 30.61 | - | +| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | 48.41 | 38.31 | 38.45 | 27.72 | - | +| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | 49.96 | 41.10 | 39.83 | 33.00 | - | +| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | 50.25 | 40.99 | 40.04 | 30.54 | - | +| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 | ### ColossalChat @@ -406,10 +406,10 @@ pip install colossalai **注:目前只支持Linux。** -但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`CUDA_EXT=1`. +但是,如果你想在安装时就直接构建PyTorch扩展,您可以设置环境变量`BUILD_EXT=1`. ```bash -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` **否则,PyTorch扩展只会在你实际需要使用他们时在运行时里被构建。** @@ -438,7 +438,7 @@ pip install . 我们默认在`pip install`时不安装PyTorch扩展,而是在运行时临时编译,如果你想要提前安装这些扩展的话(在使用融合优化器时会用到),可以使用一下命令。 ```shell -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ```
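For context on what the `CUDA_EXT` to `BUILD_EXT` rename controls: the `setup.py` hunk at the end of this patch gates ahead-of-time kernel compilation on that environment variable. Below is a minimal, hypothetical sketch of that pattern, not ColossalAI's actual `setup.py`; the package and kernel names are invented for illustration.

```python
import os

from setuptools import setup

# BUILD_EXT=1 requests ahead-of-time (AOT) compilation of the CUDA kernels at install time;
# otherwise they are built lazily at runtime.
BUILD_EXT = int(os.environ.get("BUILD_EXT", 0)) == 1

ext_modules = []
cmdclass = {}
if BUILD_EXT:
    # AOT builds need PyTorch's C++/CUDA build helpers to be importable during `pip install`.
    try:
        from torch.utils.cpp_extension import BuildExtension, CUDAExtension
    except ImportError as e:
        raise ModuleNotFoundError(
            "PyTorch is not found while BUILD_EXT=1; install PyTorch first to build CUDA extensions."
        ) from e
    # Hypothetical kernel; a real project lists its actual .cu sources here.
    ext_modules.append(CUDAExtension(name="example_kernels", sources=["csrc/example_kernel.cu"]))
    cmdclass["build_ext"] = BuildExtension

setup(name="example-package", version="0.0.1", ext_modules=ext_modules, cmdclass=cmdclass)
```

With this layout, `BUILD_EXT=1 pip install .` compiles the kernels during installation, while a plain `pip install .` defers kernel compilation to runtime, matching the behavior the installation docs in this patch describe.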


          diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md index f9c8fe475..50325462d 100644 --- a/docs/source/en/get_started/installation.md +++ b/docs/source/en/get_started/installation.md @@ -42,7 +42,7 @@ pip install -r requirements/requirements.txt BUILD_EXT=1 pip install . ``` -If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`: +If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `BUILD_EXT`: ```shell pip install . diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index d6a1c47d6..5434551f4 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -77,7 +77,7 @@ git clone https://github.com/hpcaitech/ColossalAI.git cd ColossalAI # install colossalai -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` #### Step 3: Accelerate with flash attention by xformers (Optional) diff --git a/examples/images/diffusion/test_ci.sh b/examples/images/diffusion/test_ci.sh index 44cf47046..652db5d39 100755 --- a/examples/images/diffusion/test_ci.sh +++ b/examples/images/diffusion/test_ci.sh @@ -8,7 +8,7 @@ conda activate ldm conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch pip install transformers diffusers invisible-watermark -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt diff --git a/examples/language/llama2/README.md b/examples/language/llama2/README.md index 752453b5a..068f15cbb 100644 --- a/examples/language/llama2/README.md +++ b/examples/language/llama2/README.md @@ -53,7 +53,7 @@ We follow the hyperparameter settings from the original LLaMA paper. We use Adam Please install the latest ColossalAI from source. ```bash -CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI +BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI ``` Then install other dependencies. diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md index 45657f192..f62223c93 100644 --- a/examples/language/openmoe/README.md +++ b/examples/language/openmoe/README.md @@ -17,7 +17,7 @@ Please install the latest ColossalAI from source. ```bash -CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI +BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI ``` Then install dependencies. diff --git a/extensions/utils.py b/extensions/utils.py index 3f75f952d..d5d87a77a 100644 --- a/extensions/utils.py +++ b/extensions/utils.py @@ -154,7 +154,7 @@ def check_cuda_availability(): def set_cuda_arch_list(cuda_dir): """ This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. - Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'. + Ahead-of-time compilation occurs when BUILD_EXT=1 is set when running 'pip install'. """ cuda_available = check_cuda_availability() diff --git a/setup.py b/setup.py index e54ec41ea..ef89481e6 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ def get_version() -> str: if BUILD_EXT: if not TORCH_AVAILABLE: raise ModuleNotFoundError( - "[extension] PyTorch is not found while CUDA_EXT=1. 
You need to install PyTorch first in order to build CUDA extensions" + "[extension] PyTorch is not found while BUILD_EXT=1. You need to install PyTorch first in order to build CUDA extensions" ) from extensions import ALL_EXTENSIONS From e304e4db354906493f9e22866b8fcee5f403a829 Mon Sep 17 00:00:00 2001 From: MickeyCHAN <76671016+danyow-cheung@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:41:23 +0800 Subject: [PATCH 04/43] [hotfix] fix sd vit import error (#5420) * fix import error * Update dpt_depth.py --------- Co-authored-by: binmakeswell --- examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py index 74871e8b1..fbb0e0563 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -2,7 +2,8 @@ import torch import torch.nn as nn from .base_model import BaseModel -from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder +from .vit import forward_vit def _make_fusion_block(features, use_bn): From e239cf9060d90a22557d685c9d6ce350520c5337 Mon Sep 17 00:00:00 2001 From: Luo Yihang Date: Tue, 5 Mar 2024 21:44:38 +0800 Subject: [PATCH 05/43] [hotfix] fix typo of openmoe model source (#5403) --- examples/language/openmoe/benchmark/benchmark_cai.py | 2 +- examples/language/openmoe/benchmark/benchmark_fsdp.py | 2 +- examples/language/openmoe/infer.py | 6 +++--- examples/language/openmoe/train.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 03b660ecf..770c500d8 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -207,7 +207,7 @@ def main(): coordinator.print_on_master(f"Set plugin as {plugin}") # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name + repo_name = "hpcai-tech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) set_openmoe_args( config, diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 7f438fc5a..b00fbd001 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -53,7 +53,7 @@ def fsdp_main(rank, world_size, args): train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) torch.cuda.set_device(rank) - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name) + config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-%s" % args.model_name) set_openmoe_args( config, num_experts=config.num_experts, diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index db90c6e34..04df64531 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -15,19 +15,19 @@ def parse_args(): def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") if args.model == "test": - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base") set_openmoe_args(config, num_experts=config.num_experts, 
moe_layer_interval=config.moe_layer_interval, enable_kernel=True) model = OpenMoeForCausalLM(config) else: - config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") + config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}") set_openmoe_args(config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False) - model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) + model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config) model = model.eval().bfloat16() model = model.to(torch.cuda.current_device()) diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 1ae661f54..89c4d5420 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -269,12 +269,12 @@ def main(): # Build OpenMoe model if test_mode: - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base") config.hidden_size = 128 config.intermediate_size = 256 config.vocab_size = 32000 else: - repo_name = "hpcaitech/openmoe-" + args.model_name + repo_name = "hpcai-tech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) set_openmoe_args( config, From 70cce5cbed8046f82389f76f0916fc060aef9243 Mon Sep 17 00:00:00 2001 From: digger yu Date: Tue, 5 Mar 2024 21:45:55 +0800 Subject: [PATCH 06/43] [doc] update some translations with README-zh-Hans.md (#5382) --- docs/README-zh-Hans.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index bc4106d12..110e1a502 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -141,10 +141,10 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [[模型权重]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base) - 13B: 万元预算打造高质量13B私有模型 -[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) -[[blog]](https://hpc-ai.com/blog/colossal-llama-2-13b) -[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-13b-base) -[[Modelscope model weights]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary) +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2) +[[博客]](https://hpc-ai.com/blog/colossal-llama-2-13b) +[[HuggingFace 模型权重]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-13b-base) +[[Modelscope 模型权重]](https://www.modelscope.cn/models/colossalai/Colossal-LLaMA-2-13b-base/summary) | Model | Backbone | Tokens Consumed | MMLU (5-shot) | CMMLU (5-shot) | AGIEval (5-shot) | GAOKAO (0-shot) | CEval (5-shot) | |:------------------------------:|:----------:|:---------------:|:-------------:|:--------------:|:----------------:|:---------------:|:--------------:| @@ -264,8 +264,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

          - 700亿参数LLaMA2训练加速195% -[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) -[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training) +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2) +[[博客]](https://www.hpc-ai.tech/blog/70b-llama2-training) ### LLaMA1

          @@ -361,7 +361,8 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的

          -- [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Inference performance improved by 46%, open source solution breaks the length limit of LLM for multi-round conversations +- [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): 开源解决方案打破了多轮对话的 LLM 长度限制,推理性能提高了46% +

          From 16c96d4d8cbe26b5ee32a35fd5ee809e035c9e96 Mon Sep 17 00:00:00 2001 From: digger yu Date: Tue, 5 Mar 2024 21:47:48 +0800 Subject: [PATCH 07/43] [hotfix] fix typo change _descrption to _description (#5331) --- .../inference/engine/microbatch_manager.py | 54 +++++++++---------- .../inference/pipeline/microbatch_manager.py | 54 +++++++++---------- colossalai/pipeline/schedule/generate.py | 12 ++--- 3 files changed, 60 insertions(+), 60 deletions(-) diff --git a/colossalai/inference/engine/microbatch_manager.py b/colossalai/inference/engine/microbatch_manager.py index d698c89f9..7264b81e0 100644 --- a/colossalai/inference/engine/microbatch_manager.py +++ b/colossalai/inference/engine/microbatch_manager.py @@ -17,8 +17,8 @@ class Status(Enum): class MicroBatchDescription: """ - This is the class to record the infomation of each microbatch, and also do some update operation. - This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + This is the class to record the information of each microbatch, and also do some update operation. + This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more details, please refer to the doc of these two classes blow. Args: @@ -61,15 +61,15 @@ class MicroBatchDescription: @property def cur_length(self): """ - Return the current sequnence length of micro batch + Return the current sequence length of micro batch """ class HeadMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` - and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the information and the condition to determine the state is different from other stages. Args: @@ -123,7 +123,7 @@ class HeadMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. 
@@ -173,76 +173,76 @@ class MicroBatchManager: self.max_input_len = max_input_len self.max_output_len = max_output_len self.cache_manager_list = cache_manager_list - self.mb_descrption_buffer = {} + self.mb_description_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + def add_description(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + self.mb_description_buffer[self.idx] = HeadMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + self.mb_description_buffer[self.idx] = BodyMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. - 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. - 2. For other conditon, only receive the output of previous stage, and update the descrption. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs. + 2. For other condition, only receive the output of previous stage, and update the description. Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. new_token (torch.Tensor): the new token generated by current stage. """ - # Add descrption first if the descrption is None - self.cur_descrption.update(new_token) + # Add description first if the description is None + self.cur_description.update(new_token) return self.cur_state def export_new_tokens(self): new_tokens_list = [] - for i in self.mb_descrption_buffer.values(): + for i in self.mb_description_buffer.values(): new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def is_micro_batch_done(self): - if len(self.mb_descrption_buffer) == 0: + if len(self.mb_description_buffer) == 0: return False - for mb in self.mb_descrption_buffer.values(): + for mb in self.mb_description_buffer.values(): if mb.state != Status.DONE: return False return True def clear(self): - self.mb_descrption_buffer.clear() + self.mb_description_buffer.clear() for cache in self.cache_manager_list: cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) + def _remove_description(self): + self.mb_description_buffer.pop(self.idx) @property - def cur_descrption(self) -> MicroBatchDescription: - return self.mb_descrption_buffer.get(self.idx) + def cur_description(self) -> MicroBatchDescription: + return self.mb_description_buffer.get(self.idx) @property def cur_infer_state(self): - if self.cur_descrption is None: + if self.cur_description is None: return None - return self.cur_descrption.infer_state + return self.cur_description.infer_state @property def cur_state(self): """ - Return the state of current micro batch, when current descrption is None, the state is PREFILL + Return the state of current micro batch, when current description is None, the state is PREFILL """ - if self.cur_descrption is None: + if self.cur_description is None: return Status.PREFILL - return self.cur_descrption.state + return 
self.cur_description.state diff --git a/colossalai/legacy/inference/pipeline/microbatch_manager.py b/colossalai/legacy/inference/pipeline/microbatch_manager.py index 441cf6039..cb0a8c1a9 100644 --- a/colossalai/legacy/inference/pipeline/microbatch_manager.py +++ b/colossalai/legacy/inference/pipeline/microbatch_manager.py @@ -18,8 +18,8 @@ class Status(Enum): class MicroBatchDescription: """ - This is the class to record the infomation of each microbatch, and also do some update operation. - This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + This is the class to record the information of each microbatch, and also do some update operation. + This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more details, please refer to the doc of these two classes blow. Args: @@ -62,15 +62,15 @@ class MicroBatchDescription: @property def cur_length(self): """ - Return the current sequnence length of micro batch + Return the current sequence length of micro batch """ class HeadMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` - and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + This class is used to record the information of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schedule of pipeline, the operation to update the information and the condition to determine the state is different from other stages. Args: @@ -124,7 +124,7 @@ class HeadMicroBatchDescription(MicroBatchDescription): class BodyMicroBatchDescription(MicroBatchDescription): """ - This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + This class is used to record the information of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. @@ -174,76 +174,76 @@ class MicroBatchManager: self.max_input_len = max_input_len self.max_output_len = max_output_len self.cache_manager_list = cache_manager_list - self.mb_descrption_buffer = {} + self.mb_description_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + def add_description(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + self.mb_description_buffer[self.idx] = HeadMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + self.mb_description_buffer[self.idx] = BodyMicroBatchDescription( inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. - 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. - 2. 
For other conditon, only receive the output of previous stage, and update the descrption. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_description` will save its inputs. + 2. For other condition, only receive the output of previous stage, and update the description. Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. new_token (torch.Tensor): the new token generated by current stage. """ - # Add descrption first if the descrption is None - self.cur_descrption.update(new_token) + # Add description first if the description is None + self.cur_description.update(new_token) return self.cur_state def export_new_tokens(self): new_tokens_list = [] - for i in self.mb_descrption_buffer.values(): + for i in self.mb_description_buffer.values(): new_tokens_list.extend(i.new_tokens.tolist()) return new_tokens_list def is_micro_batch_done(self): - if len(self.mb_descrption_buffer) == 0: + if len(self.mb_description_buffer) == 0: return False - for mb in self.mb_descrption_buffer.values(): + for mb in self.mb_description_buffer.values(): if mb.state != Status.DONE: return False return True def clear(self): - self.mb_descrption_buffer.clear() + self.mb_description_buffer.clear() for cache in self.cache_manager_list: cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _remove_descrption(self): - self.mb_descrption_buffer.pop(self.idx) + def _remove_description(self): + self.mb_description_buffer.pop(self.idx) @property - def cur_descrption(self) -> MicroBatchDescription: - return self.mb_descrption_buffer.get(self.idx) + def cur_description(self) -> MicroBatchDescription: + return self.mb_description_buffer.get(self.idx) @property def cur_infer_state(self): - if self.cur_descrption is None: + if self.cur_description is None: return None - return self.cur_descrption.infer_state + return self.cur_description.infer_state @property def cur_state(self): """ - Return the state of current micro batch, when current descrption is None, the state is PREFILL + Return the state of current micro batch, when current description is None, the state is PREFILL """ - if self.cur_descrption is None: + if self.cur_description is None: return Status.PREFILL - return self.cur_descrption.state + return self.cur_description.state diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 20f316c2a..d6a6aec63 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -95,7 +95,7 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state} + model_inputs = {"infer_state": self.mb_manager.cur_description.infer_state} return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -107,7 +107,7 @@ class GenerateSchedule(PipelineSchedule): Returns: dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` """ - new_mask = self.mb_manager.cur_descrption.attn_mask + new_mask = self.mb_manager.cur_description.attn_mask return dict(input_ids=new_token, attention_mask=new_mask) @@ -133,7 +133,7 @@ class GenerateSchedule(PipelineSchedule): 1.Load micro_batch 
2.Use the current micro_batch to init the current infer_state """ inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) def _load_stage_action(self, model: Module) -> None: """ @@ -141,7 +141,7 @@ class GenerateSchedule(PipelineSchedule): 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -379,7 +379,7 @@ class GenerateSchedule(PipelineSchedule): if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) interval_inputs = {"infer_state": self.mb_manager.cur_infer_state} output_dict = model_forward(model, inputs_dict, interval_inputs) # In GENERATE phase @@ -415,7 +415,7 @@ class GenerateSchedule(PipelineSchedule): inputs_dict = None if self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() - self.mb_manager.add_descrption(inputs_dict) + self.mb_manager.add_description(inputs_dict) interval_inputs = { "hidden_states": hidden_states["hidden_states"], "infer_state": self.mb_manager.cur_infer_state, From 049121d19d7ead4a4dcbeb091df9ff87ba991a63 Mon Sep 17 00:00:00 2001 From: digger yu Date: Tue, 5 Mar 2024 21:48:46 +0800 Subject: [PATCH 08/43] [hotfix] fix typo change enabel to enable under colossalai/shardformer/ (#5317) --- colossalai/shardformer/layer/_operation.py | 4 ++-- colossalai/shardformer/layer/normalization.py | 8 ++++---- colossalai/shardformer/modeling/bloom.py | 2 +- .../shardformer/modeling/chatglm2_6b/modeling_chatglm.py | 6 +++--- colossalai/shardformer/modeling/gptj.py | 4 ++-- colossalai/shardformer/modeling/llama.py | 4 ++-- colossalai/shardformer/modeling/opt.py | 2 +- colossalai/shardformer/modeling/t5.py | 2 +- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index d4960c7e4..241770901 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -173,7 +173,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. - overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + overlap (`bool`): Whether to overlap the all_gather op and gradient calculate in backward. 
""" @@ -534,7 +534,7 @@ class HookParameter(torch.autograd.Function): return grad_output, None, None -def hook_paramter_in_backward(input, weight=None, bias=None): +def hook_parameter_in_backward(input, weight=None, bias=None): return HookParameter.apply(input, weight, bias) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 4aa281290..43dd153af 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -7,7 +7,7 @@ import torch.nn as nn from colossalai.lazy import LazyInitContext -from ._operation import hook_paramter_in_backward +from ._operation import hook_parameter_in_backward from .utils import SeqParallelUtils __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] @@ -29,7 +29,7 @@ try: def forward(self, input): output = super().forward(input) - output = hook_paramter_in_backward(output, self.weight, self.bias) + output = hook_parameter_in_backward(output, self.weight, self.bias) return output class FusedRMSNormWithHook(ApexFusedRMSNorm): @@ -38,7 +38,7 @@ try: def forward(self, input): output = super().forward(input) - output = hook_paramter_in_backward(output, self.weight) + output = hook_parameter_in_backward(output, self.weight) return output except ImportError: @@ -79,7 +79,7 @@ if EnableFastLayerNorm: def forward(self, input): output = super().forward(input) - output = hook_paramter_in_backward(output, self.weight, self.bias) + output = hook_parameter_in_backward(output, self.weight, self.bias) return output diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index cd8a02330..d94c30d29 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -699,7 +699,7 @@ class BloomPipelineForwards: return {"hidden_states": hidden_states} -def get_bloom_flash_attention_forward(enabel_jit_fused=False): +def get_bloom_flash_attention_forward(enable_jit_fused=False): try: from xformers.ops import memory_efficient_attention as me_attention except: diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index 71aa2296e..bf581300a 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -181,7 +181,7 @@ class RotaryEmbedding(nn.Module): cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - # this is to mimic the behaviour of complex32, else we will get different results + # this is to mimic the behavior of complex32, else we will get different results if dtype in (torch.float16, torch.bfloat16, torch.int8): cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() return cache @@ -290,7 +290,7 @@ class CoreAttention(torch.nn.Module): # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocting input tensor: [b * np, sq, sk] + # preallocating input tensor: [b * np, sq, sk] matmul_input_buffer = torch.empty( output_size[0] * output_size[1], output_size[2], @@ -1289,7 +1289,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. 
" - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + "This behavior is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 22b0f7a90..1990d7df3 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -122,7 +122,7 @@ class GPTJPipelineForwards: # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - # position id to be asssigned not just for the first stage for attn input + # position id to be assigned not just for the first stage for attn input if position_ids is not None: position_ids = position_ids.view(-1, seq_length) else: @@ -593,7 +593,7 @@ def get_gptj_flash_attention_forward(): # key = key.permute(0, 2, 1, 3) # query = query.permute(0, 2, 1, 3) - key = key.to(dtype=value.dtype) # fp16 compatability + key = key.to(dtype=value.dtype) # fp16 compatibility query = query.to(dtype=value.dtype) if layer_past is not None: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 92c709218..f20ceb2d6 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -225,13 +225,13 @@ class LlamaPipelineForwards: >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 7f6cbbbcf..d0e267eac 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -123,7 +123,7 @@ class OPTPipelineForwards: else: if hidden_states is None: - raise ValueError("hidden_states shouln't be None for intermediate stages.") + raise ValueError("hidden_states shouldn't be None for intermediate stages.") input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index dcb178520..9c5ce3fb6 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -77,7 +77,7 @@ class T5PipelineForwards: if in_decoder != (stage >= decoder_starting_stage): raise ValueError("Config in T5Stack is not aligned with pipeline setting.") - # at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/input_embedds + # at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/input_embeds # at_last_stage: current stage is the last stage of encoder/decoder, making outputs the same form as huggingface at_first_stage = (stage == 0) or (stage == decoder_starting_stage) at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1) From a7ae2b5b4c05f194fc563c488ed885f4a225fbc8 Mon Sep 17 00:00:00 2001 From: Dongruixuan Li Date: Tue, 5 Mar 2024 08:48:55 -0500 Subject: [PATCH 09/43] [eval-hotfix] set few_shot_data to None when few shot is disabled (#5422) --- applications/ColossalEval/colossal_eval/dataset/agieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index d018a2ba5..32f8544e9 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -201,7 +201,7 @@ class AGIEvalDataset(BaseDataset): for file in files: dataset_name = os.path.basename(file)[0 : -len(".jsonl")] - few_shot_data = [] + few_shot_data = None if few_shot: # process demo once if it is few-shot-CoT few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False) From 5e1c93d732a3ee77657fe68c031dacdeae4d7807 Mon Sep 17 00:00:00 2001 From: digger yu Date: Tue, 5 Mar 2024 21:52:30 +0800 Subject: [PATCH 10/43] [hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335) Co-authored-by: binmakeswell --- colossalai/accelerator/api.py | 2 +- colossalai/booster/plugin/gemini_plugin.py | 2 +- colossalai/booster/plugin/hybrid_parallel_plugin.py | 2 +- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 7 ++++--- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++-- colossalai/moe/__init__.py | 4 ++-- colossalai/moe/checkpoint.py | 4 ++-- 7 files changed, 13 insertions(+), 12 deletions(-) diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py index 02b3055d7..85a19b6a6 100644 --- a/colossalai/accelerator/api.py +++ b/colossalai/accelerator/api.py @@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: def auto_set_accelerator() -> None: """ Automatically check if 
any accelerator is available. - If an accelerator is availabe, set it as the global accelerator. + If an accelerator is available, set it as the global accelerator. """ global _ACCELERATOR diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 95b96bbfd..6c5033773 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase): ) def __del__(self): - """Destroy the prcess groups in ProcessGroupMesh""" + """Destroy the process groups in ProcessGroupMesh""" self.pg_mesh.destroy_mesh_process_groups() def support_no_sync(self) -> bool: diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bf677e052..8cc76dd3e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.max_norm = max_norm def __del__(self): - """Destroy the prcess groups in ProcessGroupMesh""" + """Destroy the process groups in ProcessGroupMesh""" self.pg_mesh.destroy_mesh_process_groups() @property diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 45e5a23c1..454710fcc 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( ) from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MOE_MANAGER, MoECheckpintIO +from colossalai.moe import MOE_MANAGER, MoECheckpointIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig @@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): **_kwargs, ) - def get_checkpoint_io(self) -> MoECheckpintIO: + + def get_checkpoint_io(self) -> MoECheckpointIO: if self.checkpoint_io is None: - self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) else: self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 36df30335..808227249 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. - verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. + verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True. 
""" def __init__( @@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): # obtain updated param group new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change. updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 6dd0a5fc3..cc33c77f3 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,4 +1,4 @@ -from .checkpoint import MoECheckpintIO +from .checkpoint import MoECheckpointIO from .experts import MLPExperts from .layers import SparseMLP, apply_load_balance from .manager import MOE_MANAGER @@ -14,7 +14,7 @@ __all__ = [ "NormalNoiseGenerator", "UniformNoiseGenerator", "SparseMLP", - "MoECheckpintIO", + "MoECheckpointIO", "MOE_MANAGER", "apply_load_balance", ] diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index b37ffabea..59a0ec3f0 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import ( ) -class MoECheckpintIO(HybridParallelCheckpointIO): +class MoECheckpointIO(HybridParallelCheckpointIO): def __init__( self, dp_group: ProcessGroup, @@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): # obtain updated param group new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change. updated_groups.append(new_pg) # ep param group if len(optimizer.optim.param_groups) > len(saved_groups): From c8003d463b7b0aa4d8c3318355ee43840b5f3cda Mon Sep 17 00:00:00 2001 From: hugo-syn <61210734+hugo-syn@users.noreply.github.com> Date: Tue, 5 Mar 2024 15:02:08 +0100 Subject: [PATCH 11/43] [doc] Fix typo s/infered/inferred/ (#5288) Signed-off-by: hugo-syn --- applications/ColossalQA/colossalqa/prompt/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/ColossalQA/colossalqa/prompt/README.md b/applications/ColossalQA/colossalqa/prompt/README.md index e5c74906b..01f51fef7 100644 --- a/applications/ColossalQA/colossalqa/prompt/README.md +++ b/applications/ColossalQA/colossalqa/prompt/README.md @@ -24,7 +24,7 @@ Assistant: ### English ``` [INST] <>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. -If the answer cannot be infered based on the given context, please don't share false information.<> +If the answer cannot be inferred based on the given context, please don't share false information.<> Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed. context: @@ -51,7 +51,7 @@ Note that the "[INST] <>...<>" template is the specific prompt format #### Normal Length ``` [INST] <>Always answer as helpfully as possible, while being safe. 
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. -If the answer cannot be infered based on the given context, please don't share false information.<> +If the answer cannot be inferred based on the given context, please don't share false information.<> Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed. context: @@ -73,7 +73,7 @@ Assistant: #### Overlength ``` [INST] <>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. -If the answer cannot be infered based on the given context, please don't share false information.<> +If the answer cannot be inferred based on the given context, please don't share false information.<> Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed. context: From 68f55a709c639223c89423c18d55d73806d4f2f5 Mon Sep 17 00:00:00 2001 From: Youngon Date: Tue, 5 Mar 2024 22:03:40 +0800 Subject: [PATCH 12/43] [hotfix] fix stable diffusion inference bug. (#5289) * Update train_ddp.yaml delete "strategy" to fix DDP config loading bug in "main.py" * Update train_ddp.yaml fix inference with scripts/txt2img.py config file load bug. * Update README.md add pretrain model test code. --- examples/images/diffusion/README.md | 2 ++ examples/images/diffusion/configs/train_ddp.yaml | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 5434551f4..4975dc2e3 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -202,6 +202,8 @@ python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml ``` ## Inference +if you want to test with pretrain model,as bellow: +python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --outdir ./output --ckpt 512-base-ema.ckpt --config configs/train_ddp.yaml You can get your training last.ckpt and train config.yaml in your `--logdir`, and run by ``` diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index 72dc05b64..56e1b5705 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -1,5 +1,6 @@ model: base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -86,8 +87,6 @@ lightning: max_epochs: 2 precision: 16 auto_select_gpus: False - strategy: - find_unused_parameters: False log_every_n_steps: 2 # max_steps: 6o logger: True From 743e7fad2f03ab0082d40c5da960aca29daed76e Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:58:56 +0800 Subject: [PATCH 13/43] [colossal-llama2] add stream chat examlple for chat version model (#5428) * add stream chat for chat version * remove os.system clear * modify function name --- .../utils/stream_chat_patch.py | 247 ++++++++++++++++++ .../Colossal-LLaMA-2/stream_chat_example.py | 55 ++++ 2 files changed, 302 insertions(+) 
create mode 100644 applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py create mode 100644 applications/Colossal-LLaMA-2/stream_chat_example.py diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py new file mode 100644 index 000000000..8f8eecb18 --- /dev/null +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from typing import Optional, List, Dict, Tuple, Callable, Any + +import torch +from torch import nn + +from transformers import PreTrainedTokenizer +from transformers.utils import logging +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList + +logger = logging.get_logger(__name__) + + +def get_prompt_template( + input_query:str, + history:List[Dict]= None, + roles:list = ["", "Human", "Assistant"], +) -> str: + """ + Generates a prompt template for chat models based on input and history. + + Args: + input_query (str): User's current input query. + history (List[Dict], optional): List of past conversations, each a dict with 'role' and 'message'. + roles (list): Specifies the roles in the conversation, defaults to ["", "Human", "Assistant"]. + + Returns: + str: A formatted prompt including the input query and history. + """ + prompt = "" + if history is None: + new_history = [] + else: + new_history = deepcopy(history) + + new_history.append({"role": roles[1], "message": input_query.strip()}) + new_history.append({"role": roles[2], "message": None}) + + for _, item in enumerate(new_history): + role = item.get("role") + message = item.get("message") + if role == roles[0]: + prompt += f"{message}\n\n" + else: + if message: + prompt += f"{role}: {message}" + else: + prompt += f"{role}: " + return prompt + +@torch.inference_mode() +def streaming_chat( + model: Any, + tokenizer: PreTrainedTokenizer, + input_query: str, + history: List[Dict] = None, + roles: list = ["", "Human", "Assistant"], + past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None, + temperature: float = 0.8, + top_p: float = 0.95, + top_k: int = 50, + do_sample: bool = True, + length_penalty: float = 1.2, + max_new_tokens: int = 512, + logits_processor: LogitsProcessorList = None, + return_past_key_values: bool = False, + **kwargs, +): + """ + Streaming chat responses generation with a given model and tokenizer. + + Args: + model (Any): The language model to generate responses. + tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model, used for encoding inputs and decoding responses. + input_query (str): The current user input to respond to. + history (List[Dict], optional): A list of past conversations, where each conversation is a dictionary with keys 'role' and 'message'. + roles (list): Roles involved in the conversation, defaults to ["", "Human", "Assistant"]. + past_key_values (Tuple[Tuple[torch.FloatTensor, Any], Any], optional): Past key values for incremental decoding. + temperature (float): The temperature value for token sampling, defaults to 0.8. + top_p (float): Nucleus sampling probability threshold, defaults to 0.95. + top_k (int): Top-K filtering threshold, defaults to 50. + do_sample (bool): Whether to sample responses, defaults to True. + length_penalty (float): Penalty for response length, defaults to 1.2. + max_new_tokens (int): Maximum number of new tokens to generate, defaults to 512. 
+ logits_processor (LogitsProcessorList, optional): Custom logits processors, defaults to None. + return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False. + **kwargs: Additional keyword arguments for generation. + + Yields: + Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and + optionally the updated past key values if `return_past_key_values` is True. + + Ensures padding is on the left side for the tokenizer. + """ + assert tokenizer.padding_side == "left", "Current generation only supports left padding." + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + + generation_kwargs = { + 'temperature': temperature, + 'top_p': top_p, + 'top_k': top_k, + 'do_sample': do_sample, + 'max_new_tokens': max_new_tokens, + 'length_penalty': length_penalty, + 'use_cache': True, + **kwargs + } + + prompt_str = get_prompt_template(input_query, history=history, roles=roles) + + eos_token_id = [tokenizer.eos_token_id] + inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device) + history.append({"role": roles[1], "message": input_query.strip()}) + history.append({"role": roles[2], "message": None}) + + for outputs in stream_generate(model, **inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **generation_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + + history[-1]["message"] = response.strip() + if return_past_key_values: + yield response, history, past_key_values + else: + yield response, history + + +@torch.inference_mode() +def stream_generate( + model: Any, + input_ids: torch.Tensor, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values: bool = False, + **kwargs, +): + """ + Generates sequences of token ids using the specified model and generation parameters. + Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py + + Args: + model (Any): The model used for generating sequences of token ids. + input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder. + generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call. + logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments + and generation config. + stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments + and a generation config. + prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation. + return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False. + **kwargs: Additional parameters for model generation. + + Yields: + torch.Tensor: The generated token IDs, updated after each generation step. 
+ Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]: The past key values, returned if `return_past_key_values` is True, defaults to False. + """ + input_ids_len = input_ids.size(1) + + if generation_config is None: + generation_config = model.generation_config + generation_config = deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + + if generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_len + + if input_ids_len >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + # prepare distribution pre_processing samplers + logits_processor = model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_len, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # prepare stopping criteria + stopping_criteria = model._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + + logits_warper = model._get_logits_warper(generation_config) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + + while True: + model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + # NOTE: this is correct only in left padding mode + # pre-process distribution + next_token_logits = outputs.logits[:, -1, :] + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = model._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/stream_chat_example.py b/applications/Colossal-LLaMA-2/stream_chat_example.py new file mode 100644 index 000000000..3e45c690f --- 
/dev/null +++ b/applications/Colossal-LLaMA-2/stream_chat_example.py @@ -0,0 +1,55 @@ +import os +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +from colossal_llama2.utils.stream_chat_patch import streaming_chat + +SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + +def main(args): + model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval() + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + + past_key_values, history = None, [] + roles = ["", "Human", "Assistant"] + + history = [] + history.append({"role": roles[0], "message": SYSTEM}) + + while True: + input_query = input(f"\n{roles[1]}: ") + if input_query.strip() == "exit": + break + if input_query.strip() == "clear": + past_key_values, history = None, [] + continue + + print(f"\n{roles[2]}: ", end="") + gen_len = 0 + for response, history, past_key_values in streaming_chat( + model, tokenizer, input_query, history=history, roles=roles, + temperature = args.temperature, + top_p = args.top_p, + top_k = args.top_k, + do_sample = args.do_sample, + length_penalty = args.length_penalty, + max_new_tokens = args.max_new_tokens, + past_key_values=past_key_values, + return_past_key_values=True): + + output = response[gen_len:] + print(output, end="", flush=True) + gen_len = len(response) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', type=str, default=None, help="path to chat version model") + parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer") + parser.add_argument('--temperature', type=float, default=0.8, help="set temperature") + parser.add_argument('--top_p', type=float, default=0.95, help="set top p value") + parser.add_argument('--top_k', type=int, default=50, help="set top k value") + parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not") + parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty") + parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens") + args = parser.parse_args() + main(args) \ No newline at end of file From 8020f4263095373e4c7ad1b15e54b966a8ccb683 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 7 Mar 2024 23:36:07 +0800 Subject: [PATCH 14/43] [release] update version (#5411) --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index c2c0004f0..449d7e73a 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.5 +0.3.6 From da885ed5405c4472f18825f80c98bb505bfad23b Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:49:58 +0800 Subject: [PATCH 15/43] fix tensor data update for gemini loss caluculation (#5442) --- applications/Colossal-LLaMA-2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 2e4bab75a..d97da61e4 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -56,6 +56,7 @@ def format_numel_str(numel: int) -> str: def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor From 385e85afd460a1b9a947b09c9d0f7d2628c35ad2 Mon 
Sep 17 00:00:00 2001 From: digger yu Date: Tue, 12 Mar 2024 11:25:16 +0800 Subject: [PATCH 16/43] [hotfix] fix typo s/keywrods/keywords etc. (#5429) --- .../ColossalQA/colossalqa/chain/retrieval_qa/base.py | 8 ++++---- applications/ColossalQA/colossalqa/prompt/prompt.py | 6 +++--- .../ColossalQA/colossalqa/retrieval_conversation_en.py | 2 +- .../colossalqa/retrieval_conversation_universal.py | 2 +- .../ColossalQA/colossalqa/retrieval_conversation_zh.py | 2 +- .../ColossalQA/examples/retrieval_conversation_chatgpt.py | 2 +- .../ColossalQA/examples/retrieval_conversation_en.py | 4 ++-- .../retrieval_conversation_en_customer_service.py | 2 +- .../examples/retrieval_conversation_universal.py | 2 +- .../ColossalQA/examples/retrieval_conversation_zh.py | 2 +- .../ColossalQA/examples/webui_demo/RAG_ChatBot.py | 2 +- colossalai/inference/README.md | 2 +- colossalai/legacy/inference/README.md | 2 +- colossalai/legacy/inference/hybridengine/engine.py | 4 ++-- colossalai/nn/optimizer/README.md | 2 +- colossalai/pipeline/schedule/generate.py | 4 ++-- .../diffusion/ldm/modules/diffusionmodules/openaimodel.py | 2 +- 17 files changed, 25 insertions(+), 25 deletions(-) diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py index e80befdac..a6e87e6be 100644 --- a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py +++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py @@ -117,8 +117,8 @@ class CustomBaseRetrievalQA(BaseRetrievalQA): ) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup) # if rejection_trigger_keywords is not given, return the response from LLM directly - rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', []) - answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None + rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', []) + answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) else None if answer is None: answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") if self.combine_documents_chain.memory is not None: @@ -161,8 +161,8 @@ class CustomBaseRetrievalQA(BaseRetrievalQA): input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs ) # if rejection_trigger_keywords is not given, return the response from LLM directly - rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', []) - answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None + rejection_trigger_keywords = inputs.get('rejection_trigger_keywords', []) + answer = answer if all([rej not in answer for rej in rejection_trigger_keywords]) or len(rejection_trigger_keywords)==0 else None if answer is None: answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。") self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer}) diff --git a/applications/ColossalQA/colossalqa/prompt/prompt.py b/applications/ColossalQA/colossalqa/prompt/prompt.py index 533f0bd55..d62249ba9 100644 --- a/applications/ColossalQA/colossalqa/prompt/prompt.py +++ b/applications/ColossalQA/colossalqa/prompt/prompt.py @@ -75,7 +75,7 @@ Assistant: 我认识一个叫张三的人 # Below are English retrieval qa prompts _EN_RETRIEVAL_QA_PROMPT = """[INST] <>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content. 
-If the answer cannot be infered based on the given context, please say "I cannot answer the question based on the information given.".<> +If the answer cannot be inferred based on the given context, please say "I cannot answer the question based on the information given.".<> Use the context and chat history to answer the question. context: @@ -97,8 +97,8 @@ Chat history: Human: I have a friend, Mike. Do you know him? Assistant: Yes, I know a person named Mike -sentence: What's his favorate food? -disambiguated sentence: What's Mike's favorate food? +sentence: What's his favorite food? +disambiguated sentence: What's Mike's favorite food? [/INST] Chat history: {chat_history} diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py index d2626321d..96bce82b9 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py @@ -80,7 +80,7 @@ class EnglishRetrievalConversation: self.retrieval_chain.run( query=user_input, stop=[self.memory.human_prefix + ": "], - rejection_trigger_keywrods=["cannot answer the question"], + rejection_trigger_keywords=["cannot answer the question"], rejection_answer="Sorry, this question cannot be answered based on the information provided.", ).split("\n")[0], self.memory, diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py index 76bec715f..b23058d6d 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py @@ -103,7 +103,7 @@ class UniversalRetrievalConversation: break data_name = input("Enter a short description of the data:") separator = input( - "Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:" + "Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n', press ENTER directly to skip:" ) separator = separator if separator != "" else "\n\n" retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py index 484be21c1..4eef41947 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py @@ -87,7 +87,7 @@ class ChineseRetrievalConversation: query=user_input, stop=[""], doc_prefix="支持文档", - rejection_trigger_keywrods=["无法回答该问题"], + rejection_trigger_keywords=["无法回答该问题"], rejection_answer="抱歉,根据提供的信息无法回答该问题。", ).split("\n")[0], self.memory, diff --git a/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py index 00b920d27..1042adbf2 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py +++ b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py @@ -61,7 +61,7 @@ if __name__ == "__main__": information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding) prompt_template = """Always answer as helpfully as possible, while being safe. 
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. - If the answer cannot be infered based on the given context, please don't share false information. + If the answer cannot be inferred based on the given context, please don't share false information. Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed. context: diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py index e0fe46ae6..fe2b9b4db 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en.py @@ -67,7 +67,7 @@ if __name__ == "__main__": break data_name = input("Enter a short description of the data:") separator = input( - "Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n'. Note that" + "Enter a separator to force separating text into chunks, if no separator is given, the default separator is '\\n\\n'. Note that" + "we use neural text spliter to split texts into chunks, the seperator only serves as a delimiter to force split long passage into" + " chunks before passing to the neural network. Press ENTER directly to skip:" ) @@ -112,7 +112,7 @@ if __name__ == "__main__": agent_response = retrieval_chain.run( query=user_input, stop=["Human: "], - rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS, + rejection_trigger_keywords=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS, rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER, ) agent_response = agent_response.split("\n")[0] diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py index d98a75592..d4ba73b94 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py +++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py @@ -142,7 +142,7 @@ if __name__ == "__main__": agent_response = retrieval_chain.run( query=user_input, stop=["Human: "], - rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS, + rejection_trigger_keywords=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS, rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER, ) agent_response = agent_response.split("\n")[0] diff --git a/applications/ColossalQA/examples/retrieval_conversation_universal.py b/applications/ColossalQA/examples/retrieval_conversation_universal.py index 361aa9833..5d13a63c3 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_universal.py +++ b/applications/ColossalQA/examples/retrieval_conversation_universal.py @@ -11,7 +11,7 @@ if __name__ == '__main__': parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing') args = parser.parse_args() - # Will ask for documents path in runnning time + # Will ask for documents path in running time session = UniversalRetrievalConversation(files_en=None, files_zh=None, zh_model_path=args.zh_model_path, en_model_path=args.en_model_path, diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py index cbbbefad7..b143b9baa 100644 --- 
a/applications/ColossalQA/examples/retrieval_conversation_zh.py +++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py @@ -107,7 +107,7 @@ if __name__ == "__main__": query=user_input, stop=[""], doc_prefix="支持文档", - rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, + rejection_trigger_keywords=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER, ) print(f"Agent: {agent_response}") diff --git a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py index c58be9c33..526328dda 100644 --- a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py +++ b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py @@ -140,7 +140,7 @@ class RAG_ChatBot: result = self.rag_chain.run( query=user_input, stop=[memory.human_prefix + ": "], - rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, + rejection_trigger_keywords=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER, ) return result, memory diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index dfac7cfd9..287853a86 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -89,7 +89,7 @@ docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash # enter into docker container -cd /path/to/CollossalAI +cd /path/to/ColossalAI pip install -e . ``` diff --git a/colossalai/legacy/inference/README.md b/colossalai/legacy/inference/README.md index f466f46c1..63b5f2a75 100644 --- a/colossalai/legacy/inference/README.md +++ b/colossalai/legacy/inference/README.md @@ -86,7 +86,7 @@ docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash # enter into docker container -cd /path/to/CollossalAI +cd /path/to/ColossalAI pip install -e . # install lightllm diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py index bb0b4c77a..bc4e4fd19 100644 --- a/colossalai/legacy/inference/hybridengine/engine.py +++ b/colossalai/legacy/inference/hybridengine/engine.py @@ -46,7 +46,7 @@ class CaiInferEngine: model = LlamaForCausalLM.from_pretrained("your_path_to_model") tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") - # assume the model is infered with 2 pipeline stages + # assume the model is inferred with 2 pipeline stages inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy()) input = ["Introduce a landmark in China ","Introduce a landmark in China "] @@ -70,7 +70,7 @@ class CaiInferEngine: max_input_len: int = 32, max_output_len: int = 32, verbose: bool = False, - # TODO: implement early_stopping, and various gerneration options + # TODO: implement early_stopping, and various generation options early_stopping: bool = False, do_sample: bool = False, num_beams: int = 1, diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index e89e6217d..d3f8badc7 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -47,7 +47,7 @@ be optimized jointly to further speed up training. 2. Model Accuracy - Communication Efficiency - - Reduce Volumn of Comm. + - Reduce Volume of Comm. - Reduce Frequency of Comm. 
- Memory Efficiency - Mix-Precision Training diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index d6a6aec63..48ae54c1f 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -164,7 +164,7 @@ class GenerateSchedule(PipelineSchedule): self.timestamps[self.mb_manager.idx].append(time.time()) assert ( "logits" in logits - ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(new_token) @@ -401,7 +401,7 @@ class GenerateSchedule(PipelineSchedule): self.timestamps[self.mb_manager.idx].append(time.time()) assert ( "logits" in logits - ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(new_token) # If the current micro batch is not DONE, go through blocks diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index 614fe510f..6c80f3229 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -338,7 +338,7 @@ def count_flops_attn(model, _x, y): class QKVAttentionLegacy(nn.Module): """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + A module which performs QKV attention. 
Matches legacy QKVAttention + input/output heads shaping """ def __init__(self, n_heads): From f2e8b9ef9ff3032513732a699d766bcde1a3506e Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 13 Mar 2024 15:24:13 +0800 Subject: [PATCH 17/43] [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon --- .compatibility | 3 +-- .github/workflows/build_on_schedule.yml | 3 +-- .../compatiblity_test_on_dispatch.yml | 5 ++--- .github/workflows/compatiblity_test_on_pr.yml | 5 ++--- .../compatiblity_test_on_schedule.yml | 5 ++--- .../plugin/moe_hybrid_parallel_plugin.py | 3 +-- .../test_plugin/test_gemini_plugin.py | 18 ++---------------- .../test_gemini_checkpoint_io.py | 10 +--------- .../test_model/test_shard_falcon.py | 3 +++ 9 files changed, 15 insertions(+), 40 deletions(-) diff --git a/.compatibility b/.compatibility index a918cb162..d90a74b58 100644 --- a/.compatibility +++ b/.compatibility @@ -1,2 +1 @@ -2.0.0-11.7.0 -2.1.0-11.8.0 +2.1.0-12.1.0 diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 510665b46..3ff19b37b 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -67,7 +67,6 @@ jobs: --durations=0 \ tests/ env: - NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny @@ -83,4 +82,4 @@ jobs: SERVER_URL: ${{github.server_url }} REPO: ${{ github.repository }} RUN_ID: ${{ github.run_id }} - WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} \ No newline at end of file + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index a6f9582ac..764938806 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -50,7 +50,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -87,9 +87,8 @@ jobs: pip install -r requirements/requirements-test.txt - name: Unit Testing run: | - PYTHONPATH=$PWD pytest tests + PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index ede6c380a..f582b3090 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -41,7 +41,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny 
timeout-minutes: 120 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }} @@ -82,9 +82,8 @@ jobs: pip install -r requirements/requirements-test.txt - name: Unit Testing run: | - PYTHONPATH=$PWD pytest tests + PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 1cf456ff6..3348b51ec 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -38,7 +38,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: ${{ matrix.container }} - options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny + options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 120 steps: - name: Install dependencies @@ -80,10 +80,9 @@ jobs: - name: Unit Testing run: | - PYTHONPATH=$PWD pytest tests + PYTHONPATH=$PWD pytest --durations=0 tests env: DATA: /data/scratch/cifar-10 - NCCL_SHM_DISABLE: 1 LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 LLAMA_PATH: /data/scratch/llama-tiny diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 454710fcc..ae372dd03 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -182,7 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_communication: bool = True, use_ep_inside: bool = True, custom_policy: Policy = None, - checkpoint_io: Optional[MoECheckpintIO] = None, + checkpoint_io: Optional[MoECheckpointIO] = None, ) -> None: assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -341,7 +341,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): **_kwargs, ) - def get_checkpoint_io(self) -> MoECheckpointIO: if self.checkpoint_io is None: self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 0f72d2bcd..892144772 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,7 +1,6 @@ from contextlib import nullcontext from typing import Optional -import pytest import torch import torch.distributed as dist @@ -12,13 +11,7 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import ( - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - skip_if_not_enough_gpus, - spawn, -) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo @@ -177,12 +170,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -@pytest.mark.largedist 
-@skip_if_not_enough_gpus(8) -@rerun_if_address_is_in_use() -def test_gemini_plugin_3d(early_stop: bool = True): - spawn(run_dist, 8, early_stop=early_stop) - - if __name__ == "__main__": - test_gemini_plugin(early_stop=False) \ No newline at end of file + test_gemini_plugin(early_stop=False) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index daddf6dc7..ece3b4036 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -16,7 +16,6 @@ from colossalai.testing import ( clear_cache_before_run, parameterize, rerun_if_address_is_in_use, - skip_if_not_enough_gpus, spawn, ) from tests.kit.model_zoo import model_zoo @@ -178,12 +177,5 @@ def test_gemini_ckpIO(): spawn(run_dist, 4) -@pytest.mark.largedist -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_if_address_is_in_use() -def test_gemini_ckpIO_3d(): - spawn(run_dist, 8) - - if __name__ == "__main__": - test_gemini_ckpIO() \ No newline at end of file + test_gemini_ckpIO() diff --git a/tests/test_shardformer/test_model/test_shard_falcon.py b/tests/test_shardformer/test_model/test_shard_falcon.py index 963045179..5e2efcd80 100644 --- a/tests/test_shardformer/test_model/test_shard_falcon.py +++ b/tests/test_shardformer/test_model/test_shard_falcon.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -72,6 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if stage_manager is None or stage_manager.is_first_stage(): if test_config["precision"] == "fp32": atol, rtol = 2e-4, 1e-3 + if dist.get_world_size() > 4: + atol, rtol = 4e-4, 3e-2 else: atol, rtol = 5e-3, 5e-3 check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) From 5e16bf798022431318834a64c8386df5a5b46959 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 18 Mar 2024 15:55:11 +0800 Subject: [PATCH 18/43] [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert --- .../booster/plugin/hybrid_parallel_plugin.py | 10 +++++++++- colossalai/shardformer/modeling/gpt2.py | 16 ++++++++++++---- colossalai/shardformer/modeling/llama.py | 11 ++++++----- colossalai/shardformer/policies/base_policy.py | 2 +- colossalai/shardformer/shard/shard_config.py | 4 +++- tests/test_booster/test_plugin/test_3d_plugin.py | 2 +- 6 files changed, 32 insertions(+), 13 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 8cc76dd3e..c37a6b4df 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -199,7 +199,12 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + param_info = { + "param_groups": [], + "param2id": {}, + "id2param": {}, + "param2shape": {}, + } start_index = 0 for group in optim.param_groups: packed_group = {k: v for k, v in group.items() if k != "params"} @@ -899,6 +904,7 @@ class 
HybridParallelPlugin(PipelinePluginBase): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. microbatch_size (int, optional): Microbatch size when using pipeline parallelism. Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. @@ -939,6 +945,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, enable_sequence_overlap: bool = False, + parallel_output: bool = True, num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, initial_scale: float = 2**16, @@ -1035,6 +1042,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 3e5cc6015..1e22d9094 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -25,6 +25,7 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d +from ..layer._operation import gather_forward_split_backward class GPT2PipelineForwards: @@ -337,6 +338,9 @@ class GPT2PipelineForwards: else: loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -793,11 +797,12 @@ def get_gpt2_flash_attention_forward(): scale = scale * (1 / float(self.layer_idx + 1)) # use coloattention - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale - ) + if not hasattr(self, "attention"): + self.attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) - attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) + attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -1083,6 +1088,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): else: loss = loss_fct(shift_logits, shift_labels) + if not shard_config.parallel_output: + lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) + if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f20ceb2d6..eb8e9f748 100644 --- a/colossalai/shardformer/modeling/llama.py 
+++ b/colossalai/shardformer/modeling/llama.py @@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import _gather +from ..layer._operation import gather_forward_split_backward try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -290,7 +290,7 @@ class LlamaPipelineForwards: loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (logits,) + outputs[1:] @@ -485,8 +485,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention( + if not hasattr(self, "attention"): + self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) + attn_output = self.attention( query_states, key_states, value_states, @@ -593,7 +594,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): loss = loss_fct(shift_logits, shift_labels) if not shard_config.parallel_output: - logits = _gather(logits, -1, shard_config.tensor_parallel_process_group) + logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 1d2b7a570..9a49b1ba6 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -242,4 +242,4 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] stage_indices.append([start_idx, end_idx]) - return stage_indices[0] if num_model_chunks == 1 else stage_indices + return stage_indices[0] if num_model_chunks == 1 else stage_indices \ No newline at end of file diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 415fc6dd5..da27341d9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -34,8 +34,10 @@ class ShardConfig: enable_all_optimization: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False - parallel_output = True + parallel_output: bool = True extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # TODO padding vocab + # make_vocab_size_divisible_by: int = 128 # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 285c4866c..38361d803 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -260,7 +260,7 @@ def run_grad_acc_test(test_args): origin_model, origin_optimizer, dataloader=dataloader ) for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) def run_dist(rank, 
world_size, port, early_stop: bool = True): From bd998ced039d253d2a886fce315a4660142e3b06 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Mon, 18 Mar 2024 18:31:18 +0800 Subject: [PATCH 19/43] [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights --- README.md | 9 ++++++--- docs/README-zh-Hans.md | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f045c5604..2f6aa6067 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## Latest News +* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) @@ -40,7 +41,7 @@

        • Colossal-AI for Real World Applications
            -
          • Open-Sora: Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million
          • +
          • Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
          • Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific Llm Solution
          • ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
          • AIGC: Acceleration of Stable Diffusion
          • @@ -126,9 +127,11 @@ distributed training and inference in a few lines. ## Colossal-AI in the Real World ### Open-Sora -[Open-Sora](https://github.com/hpcaitech/Open-Sora):Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million +[Open-Sora](https://github.com/hpcaitech/Open-Sora):Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models [[code]](https://github.com/hpcaitech/Open-Sora) -[[blog]](https://hpc-ai.com/blog/open-sora) +[[blog]](https://hpc-ai.com/blog/open-sora-v1.0) +[[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora) +[[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

            diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 110e1a502..7d267b16f 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,7 @@ ## 新闻 +* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) * [2024/01] [Construct Refined 13B Private Model With Just $5000 USD, Upgraded Colossal-AI Llama-2 Open Source](https://hpc-ai.com/blog/colossal-llama-2-13b) @@ -39,7 +40,7 @@

          • Colossal-AI 成功案例
              -
            • Open-Sora:开源Sora复现方案,成本降低46%,序列扩充至近百万
            • +
            • Open-Sora:全面开源类Sora模型参数和所有训练细节
            • Colossal-LLaMA-2: 千元预算半天训练,效果媲美主流大模型,开源可商用中文LLaMA-2
            • ColossalChat:完整RLHF流程0门槛克隆ChatGPT
            • AIGC: 加速 Stable Diffusion
            • @@ -121,9 +122,11 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 ## Colossal-AI 成功案例 ### Open-Sora -[Open-Sora](https://github.com/hpcaitech/Open-Sora):开源Sora复现方案,成本降低46%,序列扩充至近百万 +[Open-Sora](https://github.com/hpcaitech/Open-Sora):全面开源类Sora模型参数和所有训练细节 [[代码]](https://github.com/hpcaitech/Open-Sora) -[[博客]](https://hpc-ai.com/blog/open-sora) +[[博客]](https://hpc-ai.com/blog/open-sora-v1.0) +[[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora) +[[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo)

              From d158fc0e642f6374aba0eecace8e5b2613e16e01 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Wed, 20 Mar 2024 16:08:41 +0800 Subject: [PATCH 20/43] [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo --- README.md | 13 ++++++------- applications/README.md | 2 +- docs/README-zh-Hans.md | 12 +++++------- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 2f6aa6067..7c234b15e 100644 --- a/README.md +++ b/README.md @@ -133,14 +133,13 @@ distributed training and inference in a few lines. [[HuggingFace model weights]](https://huggingface.co/hpcai-tech/Open-Sora) [[Demo]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) -


              (back to top)

              ### Colossal-LLaMA-2 diff --git a/applications/README.md b/applications/README.md index 8abe1e52d..120767d5c 100644 --- a/applications/README.md +++ b/applications/README.md @@ -4,7 +4,7 @@ This directory contains the applications that are powered by Colossal-AI. The list of applications include: -- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million +- [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models - [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2. - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs. - [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF. diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 7d267b16f..93045ea6a 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -128,13 +128,11 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 [[模型权重]](https://huggingface.co/hpcai-tech/Open-Sora) [[演示样例]](https://github.com/hpcaitech/Open-Sora?tab=readme-ov-file#-latest-demo) -

              + ### Colossal-LLaMA-2 From 848a574c269b20d771cdd9cb2fa84561016180c3 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 21 Mar 2024 18:07:22 +0800 Subject: [PATCH 21/43] [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --- examples/language/grok-1/README.md | 43 ++++++++ examples/language/grok-1/grok1_policy.py | 99 +++++++++++++++++++ examples/language/grok-1/inference.py | 32 ++++++ examples/language/grok-1/inference_tp.py | 50 ++++++++++ examples/language/grok-1/requirements.txt | 4 + .../language/grok-1/run_inference_fast.sh | 11 +++ .../language/grok-1/run_inference_slow.sh | 11 +++ examples/language/grok-1/test_ci.sh | 1 + examples/language/grok-1/utils.py | 46 +++++++++ 9 files changed, 297 insertions(+) create mode 100644 examples/language/grok-1/README.md create mode 100644 examples/language/grok-1/grok1_policy.py create mode 100644 examples/language/grok-1/inference.py create mode 100644 examples/language/grok-1/inference_tp.py create mode 100644 examples/language/grok-1/requirements.txt create mode 100755 examples/language/grok-1/run_inference_fast.sh create mode 100755 examples/language/grok-1/run_inference_slow.sh create mode 100755 examples/language/grok-1/test_ci.sh create mode 100644 examples/language/grok-1/utils.py diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md new file mode 100644 index 000000000..c523f9412 --- /dev/null +++ b/examples/language/grok-1/README.md @@ -0,0 +1,43 @@ +# Grok-1 Inference + +## Install + +```bash +# Make sure you install colossalai from the latest source code +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI +pip install . +cd examples/language/grok-1 +pip install -r requirements.txt +``` + +## Tokenizer preparation + +You should download the tokenizer from the official grok-1 repository. + +```bash +wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model +``` + +## Inference + +You need 8x A100 80GB or equivalent GPUs to run the inference. + +We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, and it is faster. `run_inference_slow.sh` uses auto device provided by transformers, and it is slower. + +Command format: + +```bash +./run_inference_fast.sh +./run_inference_slow.sh +``` + +`model_name_or_path` can be a local path or a model name from Hugging Face model hub. We provided weights on model hub, named `hpcaitech/grok-1`. + +Command example: + +```bash +./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model +``` + +It will take 5-10 minutes to load checkpoints. Don't worry, it's not stuck. 
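Before committing to that multi-minute checkpoint load, it can help to confirm that the `tokenizer.model` fetched in the tokenizer-preparation step above is usable on its own. The snippet below is a minimal sketch, assuming the file sits in the working directory as in the `wget` step and using only the `sentencepiece` package already pinned in this example's `requirements.txt`; the prompt string is borrowed from the defaults in `utils.py` below.

```python
# Quick sanity check of tokenizer.model before the long Grok-1 checkpoint load.
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file="tokenizer.model")  # path from the wget step above
ids = sp.encode("Hi, what's your name?")  # default prompt used by utils.py
print(f"{len(ids)} token ids:", ids)
print("round-trip:", sp.decode(ids))
```

If the text round-trips cleanly, any remaining wait when running the scripts below is just weight loading.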
diff --git a/examples/language/grok-1/grok1_policy.py b/examples/language/grok-1/grok1_policy.py new file mode 100644 index 000000000..aefea6f3d --- /dev/null +++ b/examples/language/grok-1/grok1_policy.py @@ -0,0 +1,99 @@ +from typing import Dict, Union + +import torch.nn as nn + +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class Grok1Policy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self) -> nn.Module: + if self.shard_config.enable_tensor_parallelism: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}" + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + if self.shard_config.enable_tensor_parallelism: + decoder_attribute_replacement = { + "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + decoder_submodule_replacement = [ + SubModuleReplacementDescription( + suffix="attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attn.o_proj", + target_module=Linear1D_Row, + ), + ] + for i in range(self.model.config.num_experts): + decoder_submodule_replacement.extend( + [ + SubModuleReplacementDescription( + suffix=f"moe_block.experts[{i}].linear", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix=f"moe_block.experts[{i}].linear_v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix=f"moe_block.experts[{i}].linear_1", + target_module=Linear1D_Row, + ), + ] + ) + + policy["DecoderLayer"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=decoder_submodule_replacement, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key="Grok1Model", + ) + return policy + + def postprocess(self): + return self.model + + +class Grok1ModelPolicy(Grok1Policy): + pass + + +class Grok1ForCausalLMPolicy(Grok1Policy): + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = super().module_policy() + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs={"gather_output": not self.shard_config.parallel_output}, + ), + policy=policy, + target_key="Grok1ModelForCausalLM", + ) + return policy diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py new file mode 100644 index 000000000..ca0ad0d4f --- /dev/null +++ b/examples/language/grok-1/inference.py @@ -0,0 +1,32 @@ +import time + +import torch +from sentencepiece import SentencePieceProcessor +from transformers import AutoModelForCausalLM +from utils 
import get_defualt_parser, inference, print_output + +if __name__ == "__main__": + parser = get_defualt_parser() + args = parser.parse_args() + start = time.time() + torch.set_default_dtype(torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + trust_remote_code=True, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + sp = SentencePieceProcessor(model_file=args.tokenizer) + for text in args.text: + output = inference( + model, + sp, + text, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + print_output(text, sp.decode(output)) + print(f"Overall time: {time.time() - start} seconds.") diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py new file mode 100644 index 000000000..99de60e1f --- /dev/null +++ b/examples/language/grok-1/inference_tp.py @@ -0,0 +1,50 @@ +import time + +import torch +from grok1_policy import Grok1ForCausalLMPolicy +from sentencepiece import SentencePieceProcessor +from transformers import AutoModelForCausalLM +from utils import get_defualt_parser, inference, print_output + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.utils import get_current_device + +if __name__ == "__main__": + parser = get_defualt_parser() + args = parser.parse_args() + start = time.time() + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + plugin = HybridParallelPlugin( + tp_size=coordinator.world_size, + pp_size=1, + precision="bf16", + parallel_output=False, + custom_policy=Grok1ForCausalLMPolicy(), + ) + booster = Booster(plugin=plugin) + torch.set_default_dtype(torch.bfloat16) + with LazyInitContext(default_device=get_current_device()): + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16 + ) + model, *_ = booster.boost(model) + sp = SentencePieceProcessor(model_file=args.tokenizer) + for text in args.text: + output = inference( + model.unwrap(), + sp, + text, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + if coordinator.is_master(): + print_output(text, sp.decode(output)) + coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.") diff --git a/examples/language/grok-1/requirements.txt b/examples/language/grok-1/requirements.txt new file mode 100644 index 000000000..15d5ea53a --- /dev/null +++ b/examples/language/grok-1/requirements.txt @@ -0,0 +1,4 @@ +torch>=2.1.0,<2.2.0 +colossalai>=0.3.6 +sentencepiece==0.1.99 +transformers==4.35.0 diff --git a/examples/language/grok-1/run_inference_fast.sh b/examples/language/grok-1/run_inference_fast.sh new file mode 100755 index 000000000..0dc398c53 --- /dev/null +++ b/examples/language/grok-1/run_inference_fast.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +PRETRAINED=${1:-"hpcaitech/grok-1"} +TOKENIZER=${2:-"tokenizer.model"} + +torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \ + --tokenizer "$TOKENIZER" \ + --max_new_tokens 64 \ + --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." 
\ + "将以下句子翻译成英语。 我喜欢看电影和读书。" \ + "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?" diff --git a/examples/language/grok-1/run_inference_slow.sh b/examples/language/grok-1/run_inference_slow.sh new file mode 100755 index 000000000..c64dd93b9 --- /dev/null +++ b/examples/language/grok-1/run_inference_slow.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +PRETRAINED=${1:-"hpcaitech/grok-1"} +TOKENIZER=${2:-"tokenizer.model"} + +python3 inference.py --pretrained "$PRETRAINED" \ + --tokenizer "$TOKENIZER" \ + --max_new_tokens 64 \ + --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \ + "将以下句子翻译成英语。 我喜欢看电影和读书。" \ + "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?" diff --git a/examples/language/grok-1/test_ci.sh b/examples/language/grok-1/test_ci.sh new file mode 100755 index 000000000..f6a0d6584 --- /dev/null +++ b/examples/language/grok-1/test_ci.sh @@ -0,0 +1 @@ +pip install -r requirements.txt diff --git a/examples/language/grok-1/utils.py b/examples/language/grok-1/utils.py new file mode 100644 index 000000000..f113f852e --- /dev/null +++ b/examples/language/grok-1/utils.py @@ -0,0 +1,46 @@ +import argparse + +import torch + + +class Bcolors: + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + + +def print_output(text, output): + print(f"-----\n{Bcolors.OKBLUE}{text}{Bcolors.ENDC}{output[len(text):]}") + + +@torch.no_grad() +def inference(model, sp, text, **generate_kwargs): + input_ids = sp.encode(text) + input_ids = torch.tensor([input_ids]).cuda() + attention_mask = torch.ones_like(input_ids) + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + **generate_kwargs, + } + outputs = model.generate(**inputs) + return outputs[0].tolist() + + +def get_defualt_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1") + parser.add_argument("--tokenizer", type=str, default="tokenizer.model") + parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"]) + parser.add_argument("--max_new_tokens", type=int, default=30) + parser.add_argument("--do_sample", action="store_true", default=False) + parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value") + parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering") + parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation") + return parser From 6df844b8c4946c734115b7e180b292888d857bc1 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Fri, 22 Mar 2024 15:48:12 +0800 Subject: [PATCH 22/43] [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference --- README.md | 8 ++++++++ docs/README-zh-Hans.md | 8 ++++++++ examples/language/grok-1/README.md | 5 +++++ 3 files changed, 21 insertions(+) diff --git a/README.md b/README.md index 7c234b15e..d79ec9a66 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## Latest News +* [2024/03] [Grok-1 of PyTorch + HuggingFace version is now available!](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) * [2024/03] 
[Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) @@ -72,6 +73,7 @@
            • Inference
                +
              • Grok-1: 314B model of PyTorch + HuggingFace Inference
              • SwiftInfer:Breaks the Length Limit of LLM for Multi-Round Conversations with 46% Acceleration
              • GPT-3
              • OPT-175B Online Serving for Text Generation
              • @@ -365,6 +367,12 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt ## Inference +### Grok-1 +An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1 Inference. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) +[[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +[[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) +

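In practice, the Grok-1 section added above boils down to a short HuggingFace-style loading and generation loop. The code below is a minimal sketch distilled from the example's `inference.py`, assuming the repo id from the weight link above, the transformers-compatible `Xenova/grok-1-tokenizer` that the example scripts adopt later in this series (the initial version uses the official SentencePiece `tokenizer.model`), and 8x 80GB GPUs or equivalent to hold the 314B-parameter weights:

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast

# Load the PyTorch-version Grok-1 weights. The repo ships custom modeling code
# (hence trust_remote_code), and device_map="auto" lets transformers shard the
# 314B parameters across the available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "hpcai-tech/grok-1",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()

# Transformers-compatible Grok-1 tokenizer.
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")

prompt = "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0]))
```

The tensor-parallel path (`inference_tp.py`) wraps the same model with ColossalAI's `HybridParallelPlugin` via `booster.boost`, which is where the speedup reported later in this series comes from.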
                diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 93045ea6a..352564446 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,6 +24,7 @@ ## 新闻 +* [2024/03] [Grok-1 of PyTorch + HuggingFace version is now available!](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) @@ -71,6 +72,7 @@
              • 推理
                  +
                • Grok-1: 3140亿参数PyTorch + HuggingFace推理
                • SwiftInfer:打破LLM多轮对话的长度限制,推理加速46%
                • GPT-3
                • 1750亿参数OPT在线推理服务
                • @@ -358,6 +360,12 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 ## 推理 +### Grok-1 +方便易用的Python + PyTorch + HuggingFace Grok-1 推理 +[[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) +[[博客]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +[[HuggingFace Grok-1 PyTorch模型权重]](https://huggingface.co/hpcai-tech/grok-1) +

                  diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md index c523f9412..1d234e8eb 100644 --- a/examples/language/grok-1/README.md +++ b/examples/language/grok-1/README.md @@ -1,5 +1,10 @@ # Grok-1 Inference +An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1. +[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) +[[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +[[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) + ## Install ```bash From 5fcd7795cd646205cc90785c398a02c8ac475b69 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Sun, 24 Mar 2024 20:24:11 +0800 Subject: [PATCH 23/43] [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url --- examples/language/grok-1/README.md | 45 ++++++++++--------- examples/language/grok-1/inference.py | 26 ++++++++--- examples/language/grok-1/inference_tp.py | 26 ++++++++--- examples/language/grok-1/requirements.txt | 1 - .../language/grok-1/run_inference_fast.sh | 4 +- .../language/grok-1/run_inference_slow.sh | 4 +- examples/language/grok-1/utils.py | 6 +-- 7 files changed, 69 insertions(+), 43 deletions(-) diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md index 1d234e8eb..2e5ec52c5 100644 --- a/examples/language/grok-1/README.md +++ b/examples/language/grok-1/README.md @@ -5,7 +5,7 @@ An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1. [[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) [[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) -## Install +## Installation ```bash # Make sure you install colossalai from the latest source code @@ -16,33 +16,36 @@ cd examples/language/grok-1 pip install -r requirements.txt ``` -## Tokenizer preparation - -You should download the tokenizer from the official grok-1 repository. - -```bash -wget https://github.com/xai-org/grok-1/raw/main/tokenizer.model -``` - ## Inference You need 8x A100 80GB or equivalent GPUs to run the inference. -We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, and it is faster. `run_inference_slow.sh` uses auto device provided by transformers, and it is slower. - -Command format: - -```bash -./run_inference_fast.sh -./run_inference_slow.sh -``` - -`model_name_or_path` can be a local path or a model name from Hugging Face model hub. We provided weights on model hub, named `hpcaitech/grok-1`. +We provide two scripts for inference. `run_inference_fast.sh` uses tensor parallelism provided by ColossalAI, which is faster for generation, while `run_inference_slow.sh` uses auto device provided by transformers, which is relatively slower. Command example: ```bash -./run_inference_fast.sh hpcaitech/grok-1 tokenizer.model +./run_inference_fast.sh +./run_inference_slow.sh ``` -It will take 5-10 minutes to load checkpoints. Don't worry, it's not stuck. +`MODEL_NAME_OR_PATH` can be a model name from Hugging Face model hub or a local path to PyTorch-version model checkpoints. We provided weights on model hub, named `hpcaitech/grok-1`. 
And you could also download the weights in advance using `git`: +```bash +git lfs install +git clone https://huggingface.co/hpcai-tech/grok-1 +``` + +It will take, depending on your Internet speed, several hours to tens of hours to download checkpoints (about 600G!), and 5-10 minutes to load checkpoints when it's ready to launch the inference. Don't worry, it's not stuck. + + +## Performance + +For request of batch size set to 1 and maximum length set to 100: + +| Method | Initialization-Duration(sec) | Average-Generation-Latency(sec) | +|-------------------------|------------------------------|---------------------------------| +| ColossalAI | 431.45 | 14.92 | +| HuggingFace Auto-Device | 426.96 | 48.38 | +| JAX | 147.61 | 56.25 | + +Tested on 8x80G NVIDIA H800. diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py index ca0ad0d4f..a73820af9 100644 --- a/examples/language/grok-1/inference.py +++ b/examples/language/grok-1/inference.py @@ -1,8 +1,7 @@ import time import torch -from sentencepiece import SentencePieceProcessor -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, LlamaTokenizerFast from utils import get_defualt_parser, inference, print_output if __name__ == "__main__": @@ -16,11 +15,17 @@ if __name__ == "__main__": device_map="auto", torch_dtype=torch.bfloat16, ) - sp = SentencePieceProcessor(model_file=args.tokenizer) + model.eval() + init_time = time.time() - start + + # A transformers-compatible version of the grok-1 tokenizer by Xenova + # https://huggingface.co/Xenova/grok-1-tokenizer + tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") + for text in args.text: output = inference( model, - sp, + tokenizer, text, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, @@ -28,5 +33,14 @@ if __name__ == "__main__": top_k=args.top_k, top_p=args.top_p, ) - print_output(text, sp.decode(output)) - print(f"Overall time: {time.time() - start} seconds.") + print_output(text, tokenizer.decode(output)) + + overall_time = time.time() - start + gen_latency = overall_time - init_time + avg_gen_latency = gen_latency / len(args.text) + print( + f"Initializing time: {init_time:.2f} seconds.\n" + f"Overall time: {overall_time:.2f} seconds. \n" + f"Generation latency: {gen_latency:.2f} seconds. \n" + f"Average generation latency: {avg_gen_latency:.2f} seconds. 
\n" + ) diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py index 99de60e1f..604de1487 100644 --- a/examples/language/grok-1/inference_tp.py +++ b/examples/language/grok-1/inference_tp.py @@ -2,8 +2,7 @@ import time import torch from grok1_policy import Grok1ForCausalLMPolicy -from sentencepiece import SentencePieceProcessor -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, LlamaTokenizerFast from utils import get_defualt_parser, inference, print_output import colossalai @@ -33,11 +32,17 @@ if __name__ == "__main__": args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16 ) model, *_ = booster.boost(model) - sp = SentencePieceProcessor(model_file=args.tokenizer) + model.eval() + init_time = time.time() - start + + # A transformers-compatible version of the grok-1 tokenizer by Xenova + # https://huggingface.co/Xenova/grok-1-tokenizer + tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") + for text in args.text: output = inference( model.unwrap(), - sp, + tokenizer, text, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, @@ -46,5 +51,14 @@ if __name__ == "__main__": top_p=args.top_p, ) if coordinator.is_master(): - print_output(text, sp.decode(output)) - coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.") + print_output(text, tokenizer.decode(output)) + + overall_time = time.time() - start + gen_latency = overall_time - init_time + avg_gen_latency = gen_latency / len(args.text) + coordinator.print_on_master( + f"Initializing time: {init_time:.2f} seconds.\n" + f"Overall time: {overall_time:.2f} seconds. \n" + f"Generation latency: {gen_latency:.2f} seconds. \n" + f"Average generation latency: {avg_gen_latency:.2f} seconds. \n" + ) diff --git a/examples/language/grok-1/requirements.txt b/examples/language/grok-1/requirements.txt index 15d5ea53a..2d94924ce 100644 --- a/examples/language/grok-1/requirements.txt +++ b/examples/language/grok-1/requirements.txt @@ -1,4 +1,3 @@ torch>=2.1.0,<2.2.0 colossalai>=0.3.6 -sentencepiece==0.1.99 transformers==4.35.0 diff --git a/examples/language/grok-1/run_inference_fast.sh b/examples/language/grok-1/run_inference_fast.sh index 0dc398c53..1ccd8383d 100755 --- a/examples/language/grok-1/run_inference_fast.sh +++ b/examples/language/grok-1/run_inference_fast.sh @@ -1,11 +1,9 @@ #!/usr/bin/env bash PRETRAINED=${1:-"hpcaitech/grok-1"} -TOKENIZER=${2:-"tokenizer.model"} torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \ - --tokenizer "$TOKENIZER" \ - --max_new_tokens 64 \ + --max_new_tokens 100 \ --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \ "将以下句子翻译成英语。 我喜欢看电影和读书。" \ "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?" 
diff --git a/examples/language/grok-1/run_inference_slow.sh b/examples/language/grok-1/run_inference_slow.sh index c64dd93b9..a85782892 100755 --- a/examples/language/grok-1/run_inference_slow.sh +++ b/examples/language/grok-1/run_inference_slow.sh @@ -1,11 +1,9 @@ #!/usr/bin/env bash PRETRAINED=${1:-"hpcaitech/grok-1"} -TOKENIZER=${2:-"tokenizer.model"} python3 inference.py --pretrained "$PRETRAINED" \ - --tokenizer "$TOKENIZER" \ - --max_new_tokens 64 \ + --max_new_tokens 100 \ --text "The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence." \ "将以下句子翻译成英语。 我喜欢看电影和读书。" \ "All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books?" diff --git a/examples/language/grok-1/utils.py b/examples/language/grok-1/utils.py index f113f852e..7663127a5 100644 --- a/examples/language/grok-1/utils.py +++ b/examples/language/grok-1/utils.py @@ -20,9 +20,9 @@ def print_output(text, output): @torch.no_grad() -def inference(model, sp, text, **generate_kwargs): - input_ids = sp.encode(text) - input_ids = torch.tensor([input_ids]).cuda() +def inference(model, tokenizer, text, **generate_kwargs): + input_ids = tokenizer(text, return_tensors="pt").input_ids + input_ids = input_ids.cuda() attention_mask = torch.ones_like(input_ids) inputs = { "input_ids": input_ids, From bb0a668feeae647c247a509ed69e5f6c926a045c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 25 Mar 2024 12:31:09 +0800 Subject: [PATCH 24/43] [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value --- applications/ColossalMoE/train.py | 1 - colossalai/booster/plugin/hybrid_parallel_plugin.py | 3 +++ colossalai/pipeline/schedule/one_f_one_b.py | 10 +++------- .../train_gpt_using_hybrid_parallelism.md | 2 +- .../train_vit_with_hybrid_parallelism.md | 2 +- docs/source/en/features/pipeline_parallel.md | 3 +-- docs/source/en/features/shardformer.md | 2 +- .../train_gpt_using_hybrid_parallelism.md | 2 +- .../train_vit_with_hybrid_parallelism.md | 2 +- docs/source/zh-Hans/features/pipeline_parallel.md | 3 +-- docs/source/zh-Hans/features/shardformer.md | 2 +- examples/images/vit/vit_benchmark.py | 2 +- examples/language/bert/finetune.py | 2 +- examples/language/gpt/hybridparallelism/finetune.py | 2 +- examples/language/llama2/finetune.py | 2 +- examples/language/llama2/pretrain.py | 4 ++-- examples/language/openmoe/benchmark/benchmark_cai.py | 1 - examples/language/openmoe/train.py | 1 - examples/language/opt/opt_train_demo.py | 2 +- tests/test_booster/test_plugin/test_3d_plugin.py | 2 +- .../test_hybrid_parallel_plugin_checkpoint_io.py | 5 ++--- tests/test_moe/test_moe_checkpoint.py | 1 - tests/test_pipeline/test_schedule/test_interleaved.py | 4 ++-- tests/test_pipeline/test_schedule/test_oneF_oneB.py | 4 ++-- 24 files changed, 28 insertions(+), 36 deletions(-) diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py index c567038ec..99603282b 100644 --- a/applications/ColossalMoE/train.py +++ b/applications/ColossalMoE/train.py @@ -238,7 +238,6 @@ def main(): lambda x, y: x.loss, optimizer, return_loss=True, - return_outputs=True, ) # Backward and optimize if is_pp_last_stage: diff --git 
a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c37a6b4df..f51cb060c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1183,6 +1183,9 @@ class HybridParallelPlugin(PipelinePluginBase): ) -> dict: assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" + if return_outputs: + warnings.warn("return_outputs may lead to significant extra memory consumption.") + # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index bf2f01b10..58008b98f 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,7 @@ from torch.nn import Module from torch.utils._pytree import tree_map from colossalai.accelerator import get_accelerator -from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils import get_current_device @@ -327,9 +327,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_forward(output_obj) if outputs is not None: - if isinstance(model, ModelWrapper): - model = model.unwrap() - outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) + outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( @@ -412,9 +410,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) if outputs is not None: - if isinstance(model, ModelWrapper): - model = model.unwrap() - outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) + outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} def forward_backward_step( diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index e87eafb6e..0133dfd86 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -178,7 +178,7 @@ def train_epoch( for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) # Backward and optimize if is_pp_last_stage: diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 93fed61c3..dfc2cd596 100644 --- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -231,7 +231,7 @@ def run_forward_backward( if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # run pipeline forward backward when enabling pp in hybrid parallel plugin output_dict = 
booster.execute_pipeline( - data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True + data_iter, model, criterion, optimizer, return_loss=True ) loss, outputs = output_dict["loss"], output_dict["outputs"] else: diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md index 31b20335e..d6f3cdfaf 100644 --- a/docs/source/en/features/pipeline_parallel.md +++ b/docs/source/en/features/pipeline_parallel.md @@ -198,8 +198,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model, _criterion, optimizer, - return_loss=True, - return_outputs=True) + return_loss=True) # Backward and optimize if is_pp_last_stage: loss = outputs['loss'] diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 1e633ebc0..672945ea2 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -271,7 +271,7 @@ However, if pipeline parallel is enabled, there are several usages different fro 3. Do forward and backward passing through calling `Booster.execute_pipeline` method: ```python outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) ``` Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method. diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index ae941b489..cf7d19172 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -175,7 +175,7 @@ def train_epoch( for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) # Backward and optimize if is_pp_last_stage: diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 3de41601a..92775bafb 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -234,7 +234,7 @@ def run_forward_backward( if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # run pipeline forward backward when enabling pp in hybrid parallel plugin output_dict = booster.execute_pipeline( - data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True + data_iter, model, criterion, optimizer, return_loss=True ) loss, outputs = output_dict["loss"], output_dict["outputs"] else: diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md index e68802055..38e1fbfc5 100644 --- a/docs/source/zh-Hans/features/pipeline_parallel.md +++ b/docs/source/zh-Hans/features/pipeline_parallel.md @@ -193,8 +193,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model, _criterion, optimizer, - return_loss=True, - return_outputs=True) + return_loss=True) # Backward and optimize if is_pp_last_stage: loss = outputs['loss'] diff --git 
a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index 972c48b0c..a7bcbd9f2 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -264,7 +264,7 @@ elif args.plugin == "hybrid_parallel": 3. 通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递: ```python outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) ``` 该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。 diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index 078017324..32b1ec803 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -120,7 +120,7 @@ def main(): # run pipeline forward backward batch = iter([batch]) outputs = booster.execute_pipeline( - batch, model, criterion, optimizer, return_loss=True, return_outputs=True + batch, model, criterion, optimizer, return_loss=True ) else: outputs = model(**batch) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 0b1e77fff..bd6c393a7 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -148,7 +148,7 @@ def train_epoch( for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) # Backward and optimize if is_pp_last_device: diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index eb56ee530..888f47aaa 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -145,7 +145,7 @@ def train_epoch( for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline( - train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + train_dataloader_iter, model, _criterion, optimizer, return_loss=True ) # Backward and optimize if is_pp_last_stage: diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 3dbd0cf35..122186c30 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -271,7 +271,7 @@ def main(): for step in pbar: if use_pipeline: outputs = booster.execute_pipeline( - dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + dataloader_iter, model, _criterion, optimizer, return_loss=True ) loss = outputs["loss"] else: diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index fe7d95830..7b5805b80 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -185,7 +185,7 @@ def main(): microbatch_size=1, enable_jit_fused=False, zero_stage=0, - precision="fp32", + precision=args.mixed_precision, initial_scale=1, ) else: @@ -286,7 +286,7 @@ def main(): for step in pbar: if use_pipeline: outputs = booster.execute_pipeline( - dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + dataloader_iter, model, _criterion, optimizer, return_loss=True ) loss = outputs["loss"] else: diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 770c500d8..a6d5f8bf2 100644 --- 
a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -270,7 +270,6 @@ def main(): lambda x, y: x.loss, optimizer, return_loss=True, - return_outputs=True, ) # Backward and optimize if is_pp_last_stage: diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 89c4d5420..f3267b7c6 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -340,7 +340,6 @@ def main(): lambda x, y: x.loss, optimizer, return_loss=True, - return_outputs=True, ) # Backward and optimize if is_pp_last_stage: diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index fddbc1b40..82dff1920 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -42,7 +42,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline( - dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True + dataloader, model, _criterion, optimizer, return_loss=True ) # Backward and optimize if is_pp_last_stage: diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 38361d803..61558c003 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -74,7 +74,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ loss = criterion(outputs[output_key]) return loss - booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False) + booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True) optimizer.step() except Exception as e: diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index b5cb31715..557666a80 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -75,7 +75,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf model.train() if booster.plugin.stage_manager is not None: booster.execute_pipeline( - _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + _preprocess_data(data), model, _criterion, optimizer, return_loss=True ) else: output = model(**_preprocess_data(data)) @@ -109,7 +109,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf data_for_origin = data_gen_fn() if booster.plugin.stage_manager is not None: booster.execute_pipeline( - _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False + _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True ) booster.execute_pipeline( _preprocess_data(data_for_origin), @@ -117,7 +117,6 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf _criterion, new_optimizer, return_loss=True, - return_outputs=False, ) else: old_model_loss = criterion(model(**_preprocess_data(data_for_shard))) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index d6dad2d7f..10e63592a 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -49,7 +49,6 @@ def run_fwd_bwd( lambda x, y: x.loss, 
optimizer, return_loss=True, - return_outputs=True, ) # Backward and optimize if is_pp_last_stage: diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 0e81818eb..7aa464055 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -104,7 +104,7 @@ def run_pp( torch_loss.backward() pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) # check loss @@ -134,7 +134,7 @@ def run_pp( torch_loss = criterion(torch_output) pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if stage_manager.is_last_stage(ignore_chunk=True): assert torch.allclose(torch_loss, pp_ret["loss"]) diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index a08dc6d27..e1a679890 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -100,7 +100,7 @@ def examine_pp(num_microbatch: int, batch_size: int): torch_loss = criterion(torch_output) torch_loss.backward() pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) # check loss @@ -130,7 +130,7 @@ def examine_pp(num_microbatch: int, batch_size: int): torch_loss = criterion(torch_output) pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True ) if stage_manager.is_last_stage(): assert torch.allclose(torch_loss, pp_ret["loss"]) From 34e909256c3fb43e5cec2c4c53437166557eb39a Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Mon, 25 Mar 2024 14:42:51 +0800 Subject: [PATCH 25/43] [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark --- README.md | 12 +++++++++--- docs/README-zh-Hans.md | 14 ++++++++++---- examples/language/grok-1/README.md | 10 ++++++++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index d79ec9a66..26776bdf6 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ ## Latest News -* [2024/03] [Grok-1 of PyTorch + HuggingFace version is now available!](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * 
[2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) @@ -368,10 +368,16 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt ## Inference ### Grok-1 -An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1 Inference. +

                  + +

                  + + - 314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, an easy-to-use Python + PyTorch + HuggingFace version for Inference. + [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) -[[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +[[blog]](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) [[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) +[[ModelScope Grok-1 PyTorch model weights]](https://www.modelscope.cn/models/colossalai/grok-1-pytorch/summary)

                  diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 352564446..6d243a808 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -24,7 +24,7 @@ ## 新闻 -* [2024/03] [Grok-1 of PyTorch + HuggingFace version is now available!](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) * [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0) * [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora) * [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer) @@ -361,10 +361,16 @@ Colossal-AI 为您提供了一系列并行组件。我们的目标是让您的 ## 推理 ### Grok-1 -方便易用的Python + PyTorch + HuggingFace Grok-1 推理 +

                  + +

                  + + - 3140亿参数Grok-1推理加速3.8倍,高效易用的PyTorch+HuggingFace版 + [[代码]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) -[[博客]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) -[[HuggingFace Grok-1 PyTorch模型权重]](https://huggingface.co/hpcai-tech/grok-1) +[[博客]](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) +[[HuggingFace Grok-1 PyTorch 模型权重]](https://huggingface.co/hpcai-tech/grok-1) +[[ModelScope Grok-1 PyTorch 模型权重]](https://www.modelscope.cn/models/colossalai/grok-1-pytorch/summary)

                  diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md index 2e5ec52c5..de56afc1d 100644 --- a/examples/language/grok-1/README.md +++ b/examples/language/grok-1/README.md @@ -1,9 +1,15 @@ # Grok-1 Inference -An easy-to-use Python + PyTorch + HuggingFace version of 314B Grok-1. + - 314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, an easy-to-use Python + PyTorch + HuggingFace version for Inference. + [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/grok-1) -[[blog]](https://hpc-ai.com/blog/grok-1-of-pytorch-huggingface-version-is-now-available) +[[blog]](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here) [[HuggingFace Grok-1 PyTorch model weights]](https://huggingface.co/hpcai-tech/grok-1) +[[ModelScope Grok-1 PyTorch model weights]](https://www.modelscope.cn/models/colossalai/grok-1-pytorch/summary) + +

                  + +

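As a quick cross-check of the 3.8x figure against the benchmark table added to this README earlier in the series: average generation latency is 56.25 s for the original JAX version versus 14.92 s with ColossalAI tensor parallelism, i.e. 56.25 / 14.92 ≈ 3.8x, while the HuggingFace auto-device baseline (48.38 s) is accelerated by roughly 48.38 / 14.92 ≈ 3.2x.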
                  ## Installation From 0688d92e2d7b65bf0092d54cf260a8a989784f6f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 25 Mar 2024 17:21:51 +0800 Subject: [PATCH 26/43] [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix --- colossalai/shardformer/modeling/gpt2.py | 13 +++++-------- colossalai/shardformer/modeling/llama.py | 24 ++++++------------------ colossalai/shardformer/policies/gpt2.py | 5 +++-- colossalai/shardformer/policies/llama.py | 7 +++---- tests/test_optimizer/test_nvme.py | 4 +++- 5 files changed, 20 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1e22d9094..407338b16 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -331,7 +331,7 @@ class GPT2PipelineForwards: loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) - if shard_config.enable_tensor_parallelism: + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: loss = cross_entropy_1d( shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group ) @@ -1078,15 +1078,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) - if shard_config.enable_tensor_parallelism: - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - loss = loss_fct(shift_logits, shift_labels) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + if not shard_config.parallel_output: lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index eb8e9f748..d5e02b64c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -16,7 +16,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d -from ..layer._operation import gather_forward_split_backward try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -279,7 +278,7 @@ class LlamaPipelineForwards: shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) loss = cross_entropy_1d( @@ -289,9 +288,6 @@ class LlamaPipelineForwards: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not shard_config.parallel_output: - logits = gather_forward_split_backward(logits, -1, 
shard_config.tensor_parallel_process_group) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -578,23 +574,15 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) - if not shard_config.parallel_output: - logits = gather_forward_split_backward(logits, -1, shard_config.tensor_parallel_process_group) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) if not return_dict: output = (logits,) + outputs[1:] diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 303766993..6a50d65ba 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -269,12 +269,13 @@ class GPT2LMHeadModelPolicy(GPT2Policy): GPT2LMHeadModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False} + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output} ) ], - method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } + if self.shard_config.parallel_output: + addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} module_policy.update(addon_module) if self.pipeline_stage_manager is not None: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 42bf0825b..4c454ac7f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -250,18 +250,17 @@ class LlamaForCausalLMPolicy(LlamaPolicy): policy = super().module_policy() - setattr(self.shard_config, "causal_lm", True) - if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col) + SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}) ], - method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } + if self.shard_config.parallel_output: + new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} policy.update(new_item) if self.pipeline_stage_manager: diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 4ff16bb9b..3315b3256 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,4 +1,5 @@ import torch 
+import pytest from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.testing import clear_cache_before_run, parameterize @@ -16,7 +17,8 @@ def check_params_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" - +# TODO Something wrong with ci when running this test. +@pytest.mark.skip(reason="skip because of something wrong with CI") @clear_cache_before_run() @parameterize("nvme_offload_fraction", [0.0, 0.5, 1.0]) @parameterize("nvme_offload_dir", ["./offload", None]) From 131f32a076d8cbbf9db41d60e67f0e1f1eda98a0 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Tue, 26 Mar 2024 10:19:42 +0800 Subject: [PATCH 27/43] [fix] fix grok-1 example typo (#5506) --- examples/language/grok-1/README.md | 2 +- examples/language/grok-1/run_inference_fast.sh | 2 +- examples/language/grok-1/run_inference_slow.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/language/grok-1/README.md b/examples/language/grok-1/README.md index de56afc1d..4c37934a0 100644 --- a/examples/language/grok-1/README.md +++ b/examples/language/grok-1/README.md @@ -35,7 +35,7 @@ Command example: ./run_inference_slow.sh ``` -`MODEL_NAME_OR_PATH` can be a model name from Hugging Face model hub or a local path to PyTorch-version model checkpoints. We provided weights on model hub, named `hpcaitech/grok-1`. And you could also download the weights in advance using `git`: +`MODEL_NAME_OR_PATH` can be a model name from Hugging Face model hub or a local path to PyTorch-version model checkpoints. We have provided pytorch-version checkpoint on [HuggingFace model hub](https://huggingface.co/hpcai-tech/grok-1), named `hpcai-tech/grok-1`. 
And you could also download the weights in advance using `git`: ```bash git lfs install git clone https://huggingface.co/hpcai-tech/grok-1 diff --git a/examples/language/grok-1/run_inference_fast.sh b/examples/language/grok-1/run_inference_fast.sh index 1ccd8383d..15a755b97 100755 --- a/examples/language/grok-1/run_inference_fast.sh +++ b/examples/language/grok-1/run_inference_fast.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -PRETRAINED=${1:-"hpcaitech/grok-1"} +PRETRAINED=${1:-"hpcai-tech/grok-1"} torchrun --standalone --nproc_per_node 8 inference_tp.py --pretrained "$PRETRAINED" \ --max_new_tokens 100 \ diff --git a/examples/language/grok-1/run_inference_slow.sh b/examples/language/grok-1/run_inference_slow.sh index a85782892..f5dba44b0 100755 --- a/examples/language/grok-1/run_inference_slow.sh +++ b/examples/language/grok-1/run_inference_slow.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -PRETRAINED=${1:-"hpcaitech/grok-1"} +PRETRAINED=${1:-"hpcai-tech/grok-1"} python3 inference.py --pretrained "$PRETRAINED" \ --max_new_tokens 100 \ From a7790a92e80b2812a89c30859e3464befb69e6a9 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 26 Mar 2024 15:09:05 +0800 Subject: [PATCH 28/43] [devops] fix example test ci (#5504) --- .github/workflows/example_check_on_dispatch.yml | 4 +--- .github/workflows/example_check_on_pr.yml | 4 +--- .github/workflows/example_check_on_schedule.yml | 3 +-- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 24e726b4f..d877b06ce 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -46,7 +46,7 @@ jobs: matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /data/scratch/examples-data:/data/ + options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm timeout-minutes: 15 steps: - name: 📚 Checkout @@ -60,5 +60,3 @@ jobs: echo "Testing ${dir} now" cd "${PWD}/examples/${dir}" bash test_ci.sh - env: - NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 728f059c1..6170628e1 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -78,7 +78,7 @@ jobs: matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /data/scratch/examples-data:/data/ + options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm timeout-minutes: 20 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} @@ -95,5 +95,3 @@ jobs: example_dir=${{ matrix.directory }} cd "${PWD}/examples/${example_dir}" bash test_ci.sh - env: - NCCL_SHM_DISABLE: 1 diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index efb131a86..2588ac824 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -35,6 +35,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 + options: --gpus all --rm -v /data/scratch/examples-data:/data/ -v /dev/shm timeout-minutes: 10 steps: - name: 📚 Checkout @@ -50,8 +51,6 @@ jobs: echo "Testing ${example_dir} now" cd 
"${PWD}/examples/${example_dir}" bash test_ci.sh - env: - NCCL_SHM_DISABLE: 1 - name: Notify Lark id: message-preparation From cbe34c557cb3835db8835ffdb1f96e0d7caa030a Mon Sep 17 00:00:00 2001 From: Rocky Duan Date: Tue, 26 Mar 2024 15:56:49 +0800 Subject: [PATCH 29/43] Fix ColoTensorSpec for py11 (#5440) --- colossalai/legacy/tensor/tensor_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/legacy/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py index 5bdd384e5..44d8d04b9 100644 --- a/colossalai/legacy/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec @@ -17,5 +17,5 @@ class ColoTensorSpec: """ pg: ProcessGroup - dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) + dist_attr: Optional[_DistSpec] = field(default_factory=lambda: _DistSpec(DistPlacementPattern.REPLICATE)) compute_attr: Optional[ComputeSpec] = None From 61da3fbc524c8c7939d194007d91488b89288dc5 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 26 Mar 2024 17:22:27 +0800 Subject: [PATCH 30/43] fixed layout converter caching and updated tester --- colossalai/tensor/d_tensor/layout_converter.py | 5 ++++- .../test_tensor/test_dtensor/test_layout_converter.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index abe4a86d8..667a7b78e 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta): total_steps = 0 transform_path = [] comm_action_sequence: List[CommSpec] = [] - spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + + src_shape = source_layout.get_sharded_shape_per_device() + dst_shape = target_layout.get_sharded_shape_per_device() + spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape)) if spec_pairs in self.cached_solution: # Solution Cache hit diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 4e65401bf..3bface1d2 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].logical_process_axis == 1 # checkout chached_spec_pairs_transform_path - assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path - assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence + src_shape = source_layout.get_sharded_shape_per_device() + dst_shape = target_layout.get_sharded_shape_per_device() + assert ( + layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][0] == transform_path + ) + assert ( + layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][1] + == comm_action_sequence + ) comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) From 19e1a5cf16ead982eb8818cd69e41b06a5d23b20 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 27 Mar 2024 11:19:32 +0800 Subject: [PATCH 32/43] 
[shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests --- .github/workflows/build_on_pr.yml | 4 +- .github/workflows/build_on_schedule.yml | 2 +- .../compatiblity_test_on_dispatch.yml | 2 +- .github/workflows/compatiblity_test_on_pr.yml | 2 +- .../compatiblity_test_on_schedule.yml | 2 +- colossalai/kernel/kernel_loader.py | 24 +- colossalai/nn/layer/colo_attention.py | 209 -------- colossalai/shardformer/layer/__init__.py | 3 + colossalai/shardformer/layer/attn.py | 269 +++++++++++ colossalai/shardformer/modeling/blip2.py | 39 +- colossalai/shardformer/modeling/chatglm2.py | 125 ++--- colossalai/shardformer/modeling/gpt2.py | 448 +++++++++++++----- colossalai/shardformer/modeling/gptj.py | 363 ++++++++++---- colossalai/shardformer/modeling/llama.py | 197 ++++++-- colossalai/shardformer/modeling/opt.py | 335 +++++++++---- colossalai/shardformer/modeling/vit.py | 35 +- colossalai/shardformer/modeling/whisper.py | 302 +++++++++--- colossalai/shardformer/policies/gpt2.py | 55 ++- colossalai/shardformer/policies/gptj.py | 51 +- colossalai/shardformer/policies/llama.py | 10 + colossalai/shardformer/policies/opt.py | 58 ++- colossalai/shardformer/policies/whisper.py | 24 +- colossalai/testing/comparison.py | 30 +- extensions/README.md | 4 +- extensions/__init__.py | 10 +- extensions/base_extension.py | 4 +- extensions/cpu_adam/cpu_adam_arm.py | 4 +- extensions/cpu_adam/cpu_adam_x86.py | 8 +- extensions/cuda_extension.py | 4 +- extensions/flash_attention/__init__.py | 12 +- .../flash_attention_dao_cuda.py | 99 ++-- .../flash_attention/flash_attention_npu.py | 63 +-- .../flash_attention_sdpa_cuda.py | 56 +++ .../flash_attention_xformers_cuda.py | 94 ---- setup.py | 4 +- .../test_shardformer/test_flash_attention.py | 147 ++++++ tests/test_shardformer/test_model/_utils.py | 23 +- .../test_model/test_shard_blip2.py | 51 +- .../test_model/test_shard_chatglm2.py | 69 ++- .../test_model/test_shard_gpt2.py | 77 ++- .../test_model/test_shard_gptj.py | 78 ++- .../test_model/test_shard_llama.py | 4 +- .../test_model/test_shard_opt.py | 90 +++- .../test_model/test_shard_t5.py | 56 ++- tests/test_utils/test_flash_attention.py | 167 ------- 45 files changed, 2543 insertions(+), 1170 deletions(-) delete mode 100644 colossalai/nn/layer/colo_attention.py create mode 100644 colossalai/shardformer/layer/attn.py create mode 100644 extensions/flash_attention/flash_attention_sdpa_cuda.py delete mode 100644 extensions/flash_attention/flash_attention_xformers_cuda.py create mode 100644 tests/test_shardformer/test_flash_attention.py delete mode 100644 tests/test_utils/test_flash_attention.py diff --git a/.github/workflows/build_on_pr.yml 
b/.github/workflows/build_on_pr.yml index b01d15490..5bdadca78 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -117,7 +117,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - name: Store TensorNVMe Cache run: | @@ -201,4 +201,4 @@ jobs: uses: actions/upload-artifact@v3 with: name: report - path: report/ \ No newline at end of file + path: report/ diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 3ff19b37b..e560d0c00 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -44,7 +44,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 if: steps.check-avai.outputs.avai == 'true' diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml index 764938806..95a94c27b 100644 --- a/.github/workflows/compatiblity_test_on_dispatch.yml +++ b/.github/workflows/compatiblity_test_on_dispatch.yml @@ -66,7 +66,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml index f582b3090..aef4816ef 100644 --- a/.github/workflows/compatiblity_test_on_pr.yml +++ b/.github/workflows/compatiblity_test_on_pr.yml @@ -60,7 +60,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml index 3348b51ec..3dc8a5a32 100644 --- a/.github/workflows/compatiblity_test_on_schedule.yml +++ b/.github/workflows/compatiblity_test_on_schedule.yml @@ -56,7 +56,7 @@ jobs: cd TensorNVMe apt update && apt install -y cmake pip install -r requirements.txt - pip install -v . + DISABLE_URING=1 pip install -v . - uses: actions/checkout@v2 with: ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py index 148c3e3fc..353e29b3d 100644 --- a/colossalai/kernel/kernel_loader.py +++ b/colossalai/kernel/kernel_loader.py @@ -6,7 +6,7 @@ from .extensions import ( CpuAdamX86Extension, FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FusedOptimizerCudaExtension, LayerNormCudaExtension, MoeCudaExtension, @@ -65,9 +65,9 @@ class KernelLoader: else: usable_exts = [] for ext in exts: - if ext.is_hardware_available(): + if ext.is_available(): # make sure the machine is compatible during kernel loading - ext.assert_hardware_compatible() + ext.assert_compatible() usable_exts.append(ext) assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." 
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): class FlashAttentionLoader(KernelLoader): - REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] + REGISTRY = [ + FlashAttentionNpuExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionSdpaCudaExtension, + ] + + +class FlashAttentionWithPaddingMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension] + + +class FlashAttentionWithCustomMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension] + + +class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader): + REGISTRY = [FlashAttentionSdpaCudaExtension] diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py deleted file mode 100644 index 0b7011e8e..000000000 --- a/colossalai/nn/layer/colo_attention.py +++ /dev/null @@ -1,209 +0,0 @@ -import enum -import math -import warnings -from dataclasses import dataclass -from typing import Iterable, Optional, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.accelerator import get_accelerator -from colossalai.kernel.kernel_loader import FlashAttentionLoader - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize( - attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() - ): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - paddedcausal = 3 - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] 
- out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - self.attn = FlashAttentionLoader().load() - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - origin_attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - """ - ColoAttention - - Args: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - origin_attn_mask: (nheads, q_seqlen, kv_seqlen) - bias: will not be used - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - # if flash attention is not applicable, switch to memory effcient attention - if self.attn.__name__ == "flash_attention" and ( - query.dtype not in [torch.float16, torch.bfloat16] or bias != None - ): - warnings.warn( - f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." - ) - self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... 
-> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = self.attn( - query, - key, - value, - seq_len_info_q=seq_len_info_q, - seq_len_info_kv=seq_len_info_kv, - origin_attn_mask=origin_attn_mask, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - if len(out.shape) == 4: - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 56e8b08c4..c9b4317a6 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,3 +1,4 @@ +from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -23,4 +24,6 @@ __all__ = [ "FusedRMSNorm", "FusedLinear1D_Col", "ParallelModule", + "AttnMaskType", + "ColoAttention", ] diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py new file mode 100644 index 000000000..f3f6e59d3 --- /dev/null +++ b/colossalai/shardformer/layer/attn.py @@ -0,0 +1,269 @@ +from enum import Enum +from typing import Callable, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from colossalai.kernel.kernel_loader import ( + FlashAttentionForFloatAndCustomMaskLoader, + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, + KernelLoader, +) + +__all__ = [ + "AttnMaskType", + "ColoAttention", +] + + +class AttnMaskType(Enum): + CUSTOM = 0 + PADDED = 1 + CAUSAL = 2 + PADDED_CAUSAL = 3 + + +def invert_mask(mask: torch.Tensor) -> torch.Tensor: + """Invert the mask tensor. + + Args: + mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv] + + Returns: + torch.Tensor: Inverted mask tensor. + """ + inverted_mask = 1.0 - mask + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min) + + +# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]: + """Get padding information from padding mask. + + Args: + padding_mask (torch.Tensor): Padding mask tensor. 
Shape should be [B, S] + + Returns: + Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices) + """ + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return max_seqlen_in_batch, cu_seqlens, indices + + +class ColoAttention: + _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None + + @staticmethod + def _init_kernels_dispatch(): + if ColoAttention._kernel_dispatch_map is None: + # fp16/bf16 + half_dispatch_map = { + None: FlashAttentionLoader(), + AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(), + AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(), + AttnMaskType.CAUSAL: FlashAttentionLoader(), + AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(), + } + # fp32 + float_dispatch_map = { + None: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(), + AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(), + } + ColoAttention._kernel_dispatch_map = { + torch.float16: half_dispatch_map, + torch.bfloat16: half_dispatch_map, + torch.float32: float_dispatch_map, + } + + @staticmethod + def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable: + ColoAttention._init_kernels_dispatch() + if ( + dtype not in ColoAttention._kernel_dispatch_map + or mask_type not in ColoAttention._kernel_dispatch_map[dtype] + ): + raise ValueError( + "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type) + ) + # lazy load + if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader): + ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][ + mask_type + ].load() + return ColoAttention._kernel_dispatch_map[dtype][mask_type] + + @staticmethod + def prepare_attn_kwargs( + shape_4d: Tuple[int], + dtype: torch.dtype, + device: torch.device, + q_padding_mask: Optional[torch.Tensor] = None, + kv_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + ) -> Dict[str, torch.Tensor]: + """Return a dictionary of keyword arguments for attention function. It supports 4 mask types. + 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. + 2. padded mask: receives a padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}. + 3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}. + 4. padded causal mask: receives a padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}. + + Args: + shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv) + dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype`` + device (torch.device): Device of attention mask, generally should be ``hidden_states.device`` + q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor. + The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
+ kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor. + The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. + If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. + is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. + + Returns: + Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. + """ + if q_padding_mask is None and not is_causal: + return {} + assert len(shape_4d) == 4 and shape_4d[1] == 1 + b, _, s_q, s_kv = shape_4d + outputs = {} + if (q_padding_mask is None or q_padding_mask.bool().all()) and ( + kv_padding_mask is None or kv_padding_mask.bool().all() + ): + # no padding + assert is_causal + outputs["attention_mask_type"] = AttnMaskType.CAUSAL + attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv) + else: + if kv_padding_mask is None: + # self attention + kv_padding_mask = q_padding_mask + assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == ( + b, + s_kv, + ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should match ({b}, {s_q}) and ({b}, {s_kv}) derived from shape_4d {shape_4d}." + attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device) + max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) + max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + outputs.update( + { + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_kv": cu_seqlens_kv, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_kv": max_seqlen_kv, + "q_indices": q_indices, + "kv_indices": kv_indices, + } + ) + if is_causal: + outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) + else: + outputs["attention_mask_type"] = AttnMaskType.PADDED + attention_mask = invert_mask(attention_mask).unsqueeze(1) + outputs["attention_mask"] = attention_mask + return outputs + + @staticmethod + def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + ) -> torch.Tensor: + """Flash Attention function. It supports 4 mask types. + 1. custom mask: receives attention_mask + 2. padded mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices + 3. causal mask: receives attention_mask, attention_mask_type + 4. padded causal mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices + + Args: + q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D] + k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D] + v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D] + attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None. + attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
+ cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into q. + Shape should be [B+1]. Defaults to None. + cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + Shape should be [B+1]. Defaults to None. + max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None. + max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None. + q_indices, kv_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened query and key/value sequences. + Shape should be [NUM_TOKENS]. Defaults to None. + dropout_p (float, optional): Dropout probability. Defaults to 0.0. + scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None. + + Returns: + torch.Tensor: Output tensor. Shape should be [B, N, Sq, D] + """ + # known issue: sdpa does not support an attention mask in which an entire row is masked, which leads to NaN outputs + # this case is usual when a padding mask is used and self-attention is performed + # thus, we don't use sdpa when a padding mask is used + # sanity check + if attention_mask is not None: + assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor." + if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL): + assert ( + cu_seqlens_q is None + and cu_seqlens_kv is None + and max_seqlen_q is None + and max_seqlen_kv is None + and q_indices is None + and kv_indices is None + ) + if attention_mask_type == AttnMaskType.CUSTOM: + assert not torch.all(attention_mask != 0, dim=-1).any() + elif attention_mask_type in ( + AttnMaskType.PADDED, + AttnMaskType.PADDED_CAUSAL, + ): + assert ( + cu_seqlens_q is not None + and cu_seqlens_kv is not None + and max_seqlen_q is not None + and max_seqlen_kv is not None + and q_indices is not None + and kv_indices is not None + ) + else: + # if attention_mask is None, attention_mask_type should be the default value + assert attention_mask_type == AttnMaskType.CUSTOM + # kernel dispatch + mask_type = attention_mask_type if attention_mask is not None else None + attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type) + is_causal = attention_mask is not None and attention_mask_type in ( + AttnMaskType.CAUSAL, + AttnMaskType.PADDED_CAUSAL, + ) + return attn_func( + q, + k, + v, + dropout_p=dropout_p, + scale=scale, + attention_mask=attention_mask, + is_causal=is_causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + q_indices=q_indices, + kv_indices=kv_indices, + ) diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index d5c10541a..bd84c87c6 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -3,6 +3,8 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from colossalai.shardformer.layer import ColoAttention + def forward_fn(): def forward( @@ -62,8 +64,6 @@ def forward_fn(): def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.nn.layer.colo_attention import ColoAttention - def forward( self: Blip2Attention, hidden_states: torch.Tensor, @@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward(): output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor,
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert head_mask is None, "head_mask is not supported in FlashAttention" bsz, tgt_len, embed_dim = hidden_states.size() mixed_qkv = self.qkv(hidden_states) - mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) - query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale + mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], ) - context_layer = attention(query_states, key_states, value_states) + + dropout_p = self.dropout.p if self.training else 0.0 + context_layer = ColoAttention.attention( + query_states, + key_states, + value_states, + dropout_p=dropout_p, + scale=self.scale, + ) + context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim) output = self.projection(context_layer) outputs = (output, None) @@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward(): def get_jit_fused_blip2_QFormer_self_output_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput - def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self: Blip2QFormerSelfOutput, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.LayerNorm(hidden_states) @@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward(): def get_jit_fused_blip2_QFormer_output_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput - def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self: Blip2QFormerOutput, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.LayerNorm(hidden_states) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index d13bd3492..a3e000e6e 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -1,4 +1,5 @@ """ PyTorch ChatGLM model. 
""" + from typing import List, Optional, Tuple import torch @@ -9,63 +10,49 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig +from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - from .chatglm2_6b.modeling_chatglm import CoreAttention def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split(".")[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, key_layer, value_layer, is_causal=True - ) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, key_layer, value_layer, attention_mask - ) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + attention_mask_type = AttnMaskType.CAUSAL + attn_bias = torch.zeros( + query_layer.shape[0], + 1, + query_layer.shape[2], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, + ) + temp_mask = ( + torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device) + .tril(diagonal=0) + .expand(query_layer.shape[0], 1, -1, -1) + ) + attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min) else: - # Raw attention scores - query_layer = query_layer.permute(1, 0, 2, 3).contiguous() - key_layer = key_layer.permute(1, 0, 2, 3).contiguous() - value_layer = value_layer.permute(1, 0, 2, 3).contiguous() - - scale = 1.0 / self.norm_factor - if self.coeff is not None: - scale = scale * self.coeff - - flash_attention_mask = None - attn_mask_type = None - if attention_mask is None: - attn_mask_type = AttnMaskType.causal - else: - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - attn_mask_type = AttnMaskType.paddedcausal - - attention = ColoAttention( - embed_dim=self.hidden_size_per_partition, - num_heads=self.num_attention_heads_per_partition, - dropout=self.attention_dropout.p, - scale=scale, - ) - context_layer = attention( - query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) - - context_layer = context_layer.permute(1, 0, -1).contiguous() - + attention_mask_type = AttnMaskType.CUSTOM + if attention_mask is not None: + attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype) + attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min) + dropout_p = self.attention_dropout.p if 
self.training else 0.0 + context_layer = ColoAttention.attention( + query_layer, + key_layer, + value_layer, + attention_mask=attn_bias, + attention_mask_type=attention_mask_type, + dropout_p=dropout_p, + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) return context_layer return forward @@ -169,11 +156,17 @@ class ChatGLMPipelineForwards: if self.pre_seq_len is not None: if past_key_values is None: past_key_values = self.get_prompt( - batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, ) if attention_mask is not None: attention_mask = torch.cat( - [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): @@ -200,7 +193,9 @@ class ChatGLMPipelineForwards: if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) @@ -208,7 +203,12 @@ class ChatGLMPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if self.encoder.gradient_checkpointing and self.encoder.training: layer_ret = torch.utils.checkpoint.checkpoint( - layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + past_key_values[idx], + use_cache, ) else: layer_ret = layer( @@ -224,7 +224,9 @@ class ChatGLMPipelineForwards: if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -234,7 +236,14 @@ class ChatGLMPipelineForwards: hidden_states = self.encoder.final_layernorm(hidden_states) if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): # Run encoder. 
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] inputs_embeds = split_forward_gather_backward( - inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group + inputs_embeds, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, @@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): ) hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if not return_dict: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 407338b16..72f923bf0 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig from ..layer import cross_entropy_1d from ..layer._operation import gather_forward_split_backward +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: GPT2Model, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], + attention_mask: Optional[torch.FloatTensor], + encoder_hidden_states: Optional[torch.Tensor], + encoder_attention_mask: Optional[torch.FloatTensor], +) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]: + batch_size, seq_len = hidden_states.shape[:2] + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + if shard_config.enable_flash_attention: + encoder_attention_mask = ColoAttention.prepare_attn_kwargs( + (encoder_batch_size, 1, seq_len, encoder_sequence_length), + dtype=hidden_states.dtype, + device=encoder_hidden_states.device, + q_padding_mask=attention_mask, + kv_padding_mask=encoder_attention_mask, + ) + else: + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + if shard_config.enable_flash_attention: + encoder_attention_mask = {"attention_mask": None} + else: + encoder_attention_mask = None + # GPT2Attention mask.
+ past_key_values_length = 0 + if past_key_values is not None and past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if shard_config.enable_flash_attention: + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_len, seq_len + past_key_values_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + elif attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + return attention_mask, encoder_attention_mask + class GPT2PipelineForwards: """ @@ -83,10 +153,10 @@ class GPT2PipelineForwards: elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -99,38 +169,7 @@ class GPT2PipelineForwards: input_shape = hidden_states.size()[:-1] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) - batch_size = hidden_states.shape[0] - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None + hidden_states.shape[0] # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -156,6 +195,16 @@ class GPT2PipelineForwards: output_shape = input_shape + (hidden_states.size(-1),) + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -171,7 +220,9 @@ class GPT2PipelineForwards: # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) # Going through held blocks. @@ -180,7 +231,7 @@ class GPT2PipelineForwards: block = self.h[i] torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: + if torch.is_tensor(attention_mask): attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) @@ -229,7 +280,9 @@ class GPT2PipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) if stage_manager.is_last_stage(): @@ -245,7 +298,13 @@ class GPT2PipelineForwards: if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) @@ -333,7 +392,9 @@ class GPT2PipelineForwards: shift_labels = shift_labels.view(-1) if shard_config.enable_tensor_parallelism and shard_config.parallel_output: loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, ) else: loss = loss_fct(shift_logits, shift_labels) @@ -733,27 +794,18 @@ class GPT2PipelineForwards: def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def split_heads(tensor, num_heads, attn_head_size): - """ - Splits 
hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - def forward( self: GPT2Attention, hidden_states: Optional[Tuple[torch.FloatTensor]], layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[dict] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[dict] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + assert head_mask is None, "FlashAttention does not support head_mask" if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( @@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward(): attention_mask = encoder_attention_mask else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = split_heads(query, self.num_heads, self.head_dim) - key = split_heads(key, self.num_heads, self.head_dim) - value = split_heads(value, self.num_heads, self.head_dim) + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) if layer_past is not None: past_key, past_value = layer_past @@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward(): else: present = None - if not self.is_cross_attention: - attn_mask_type = AttnMaskType.causal - flash_attention_mask = None - if attention_mask != None: - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - if attn_mask_type == AttnMaskType.causal: - attn_mask_type == AttnMaskType.paddedcausal - else: - attn_mask_type = AttnMaskType.padding - - scale = value.size(-1) ** -0.5 + scale = 1.0 + if self.scale_attn_weights: + scale /= value.size(-1) ** 0.5 if self.scale_attn_by_inverse_layer_idx: - scale = scale * (1 / float(self.layer_idx + 1)) - - # use coloattention - if not hasattr(self, "attention"): - self.attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale - ) - - attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - + scale /= float(self.layer_idx + 1) + dropout_p = self.attn_dropout.p if self.training else 0.0 + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) @@ -813,6 +849,195 @@ def get_gpt2_flash_attention_forward(): return forward +def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig): + def forward( + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + 
encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if torch.is_tensor(attention_mask): + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward + + def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self, @@ -842,10 +1067,10 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -862,40 +1087,14 @@ def 
gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): else: past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + attention_mask, encoder_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if layer_past is not None: layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: + if 
torch.is_tensor(attention_mask): attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): head_mask = head_mask.to(hidden_states.device) @@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) @@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index 1990d7df3..5c254d1e7 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -19,9 +19,54 @@ from transformers.models.gptj.modeling_gptj import ( from transformers.utils import is_torch_fx_proxy, logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: GPTJModel, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]], + attention_mask: Optional[torch.FloatTensor], +) -> Optional[Union[torch.Tensor, dict]]: + batch_size, seq_len = hidden_states.shape[:2] + past_key_values_length = 0 + if past_key_values is not None and past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if shard_config.enable_flash_attention: + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_len, seq_len + past_key_values_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + elif attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + return attention_mask + class GPTJPipelineForwards: """ @@ -96,26 +141,6 @@ class GPTJPipelineForwards: batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_attention_heads x N x N @@ -139,6 +164,8 @@ class GPTJPipelineForwards: output_shape = input_shape + (hidden_states.size(-1),) + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -154,7 +181,9 @@ class GPTJPipelineForwards: # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) # Going through held blocks. 
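The GPT-J mask helper above mirrors the GPT-2 one: when flash attention is enabled, the additive float mask is replaced by the kwargs dict produced by ColoAttention.prepare_attn_kwargs, and the patched attention forwards later unpack that dict into ColoAttention.attention(q, k, v, **attention_mask). Below is a small usage sketch of that call path, assuming a ColossalAI build that already contains this patch; the batch and sequence sizes are illustrative assumptions, not values from the patch.

import torch

from colossalai.shardformer.layer import AttnMaskType, ColoAttention

# Illustrative sizes: batch 2, sequence length 6, with the last two tokens of
# the second sample being padding.
b, s = 2, 6
padding_mask = torch.ones(b, s, dtype=torch.long)
padding_mask[1, -2:] = 0  # 1 = valid token, 0 = padding token

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (b, 1, s, s),                 # (B, 1, Sq, Skv)
    dtype=torch.float16,          # dtype of the additive mask
    device=torch.device("cpu"),   # device of the mask (illustrative)
    q_padding_mask=padding_mask,  # kv_padding_mask defaults to q_padding_mask (self-attention)
    is_causal=True,
)
# A padding mask plus is_causal=True yields the padded-causal variant described in the docstring.
assert attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL
print(sorted(attn_kwargs.keys()))
# ['attention_mask', 'attention_mask_type', 'cu_seqlens_kv', 'cu_seqlens_q',
#  'kv_indices', 'max_seqlen_kv', 'max_seqlen_q', 'q_indices']

# The patched model forwards then unpack this dict into the kernel call, e.g.
# (requires a usable fp16 kernel such as flash-attn on a CUDA device):
# q = k = v = torch.randn(b, 12, s, 64, dtype=torch.float16, device="cuda")
# out = ColoAttention.attention(q, k, v, **attn_kwargs)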
@@ -209,7 +238,9 @@ class GPTJPipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) if stage_manager.is_last_stage(): @@ -223,7 +254,14 @@ class GPTJPipelineForwards: if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None ) return BaseModelOutputWithPast( @@ -530,24 +568,11 @@ class GPTJPipelineForwards: def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def split_heads(tensor, num_attention_heads, attn_head_size, rotary): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - tensor = tensor.view(new_shape) - if rotary or len(tensor.shape) in [4, 5]: - return tensor - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - def forward( self: GPTJAttention, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, + attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, @@ -556,13 +581,14 @@ def get_gptj_flash_attention_forward(): Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], ]: + assert head_mask is None, "head_mask is not supported for FlashAttention" query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - query = split_heads(query, self.num_attention_heads, self.head_dim, True) - key = split_heads(key, self.num_attention_heads, self.head_dim, True) - value = split_heads(value, self.num_attention_heads, self.head_dim, False) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): # The logic to conditionally copy to GPU could not be traced, so we do this @@ -591,41 +617,23 @@ def get_gptj_flash_attention_forward(): key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) - # key = key.permute(0, 2, 1, 3) - # query = query.permute(0, 2, 1, 3) - key = key.to(dtype=value.dtype) # fp16 compatibility - query = query.to(dtype=value.dtype) + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) if layer_past is not None: past_key = layer_past[0] past_value = layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) if use_cache is True: present = (key, value) else: present = None - # use AttnMaskType and ColoAttention - attn_mask_type = AttnMaskType.causal - 
flash_attention_mask = None - if attention_mask != None: - if attn_mask_type == AttnMaskType.causal: - attn_mask_type == AttnMaskType.paddedcausal - else: - attn_mask_type = AttnMaskType.padding - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - - # use coloattention - scale = value.size(-1) ** -0.5 - - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale - ) - - attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) - + dropout_p = self.attn_dropout.p if self.training else 0.0 + attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p) + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) attn_output = self.out_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present, None) @@ -635,6 +643,180 @@ def get_gptj_flash_attention_forward(): return forward +def gptj_model_forward_for_flash_attention(shard_config: ShardConfig): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]).long() + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = 
self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + position_ids, + head_mask[i], + ) + else: + outputs = block( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward + + def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self, @@ -662,10 +844,10 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = 
input_ids.shape[0] + input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] + inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -684,29 +866,14 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_attention_heads x N x N @@ -725,6 +892,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) + attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask) if self.gradient_checkpointing and self.training: if use_cache: @@ -740,7 +908,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -801,7 +971,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): # When sequence parallelism done, gather the output tensor in forward and split it in backward hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) @@ -812,7 +984,16 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in 
[ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index d5e02b64c..1f17144f5 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -15,7 +15,9 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import cross_entropy_1d + +from ..layer import ColoAttention, cross_entropy_1d + try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask @@ -105,18 +107,25 @@ class LlamaPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) - if LATEST_VERSION: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True ) else: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + if LATEST_VERSION: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -262,6 +271,7 @@ class LlamaPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) past_key_values = None @@ -352,6 +362,7 @@ class LlamaPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if input_ids is not None: @@ -420,8 +431,6 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv @@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): def forward( self: LlamaAttention, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, @@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): key_states = repeat_kv(key_states, self.num_key_value_groups) 
value_states = repeat_kv(value_states, self.num_key_value_groups) - me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) - key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) - value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) - - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if not getattr(shard_config, "causal_lm", False) and attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attn_mask_type = AttnMaskType.paddedcausal - - if not hasattr(self, "attention"): - self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = self.attention( - query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type, - origin_attn_mask=attention_mask, - ) + assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." + attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): return forward +def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig): + logger = logging.get_logger(__name__) + assert shard_config.enable_flash_attention, "Flash Attention is not enabled." 
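    # The forward constructed below mirrors LlamaModel.forward, but prepares `attention_mask`
    # with ColoAttention.prepare_attn_kwargs, so each decoder layer receives a kwargs dict
    # (asserted to be a dict in the attention forward above) that is splatted into
    # ColoAttention.attention(query, key, value, **attention_mask) instead of a 4D float mask.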
+ + def forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + hidden_states = inputs_embeds + + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import LlamaForCausalLM diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index d0e267eac..a26526430 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import ( from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: OPTModel, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values_length: int, + attention_mask: Optional[torch.FloatTensor], +): + batch_size, seq_length = hidden_states.shape[:2] + mask_seq_length = past_key_values_length + seq_length + if shard_config.enable_flash_attention: + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_length, mask_seq_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + else: + attention_mask = self.decoder._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + return attention_mask class OPTPipelineForwards: @@ -26,46 +57,6 @@ class OPTPipelineForwards: under pipeline setting. 
""" - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - from transformers.models.opt.modeling_opt import _make_causal_mask - - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - _dtype, - device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( - device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - @staticmethod def opt_model_forward( self: OPTModel, @@ -81,6 +72,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward @@ -119,7 +111,7 @@ class OPTPipelineForwards: if decoder.project_in is not None: inputs_embeds = decoder.project_in(inputs_embeds) device = input_ids.device if input_ids is not None else inputs_embeds.device - _dtype = inputs_embeds.dtype + inputs_embeds.dtype else: if hidden_states is None: @@ -127,7 +119,7 @@ class OPTPipelineForwards: input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device - _dtype = hidden_states.dtype + hidden_states.dtype past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 # required mask seq length can be calculated via length of past @@ -141,13 +133,24 @@ class OPTPipelineForwards: f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( - attention_mask, input_shape, _dtype, device, past_key_values_length - ) - if stage_manager.is_first_stage(): + causal_attention_mask = _get_attention_mask( + self, + shard_config, + inputs_embeds, + past_key_values_length, + attention_mask, + ) pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) hidden_states = inputs_embeds + pos_embeds + else: + causal_attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values_length, + attention_mask, + ) if decoder.gradient_checkpointing and decoder.training: if use_cache: @@ -249,7 +252,16 @@ class OPTPipelineForwards: if stage_manager.is_last_stage(): if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_cache, + 
all_hidden_states, + all_self_attns, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -276,6 +288,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. @@ -303,6 +316,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): logits = self.lm_head(outputs[0]).contiguous() @@ -347,6 +361,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. @@ -371,6 +386,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): @@ -448,6 +464,7 @@ class OPTPipelineForwards: stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple, QuestionAnsweringModelOutput]: r""" This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. @@ -469,6 +486,7 @@ class OPTPipelineForwards: stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + shard_config=shard_config, ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -511,49 +529,47 @@ class OPTPipelineForwards: return {"hidden_states": hidden_states} -def get_opt_flash_attention_forward(): +def get_opt_flash_attention_forward(shard_config: ShardConfig): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - def forward( self: OPTAttention, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() - attention_input_shape = (bsz, -1, self.num_heads, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(*attention_input_shape) + query_states = self.q_proj(hidden_states) # get key, value proj if is_cross_attention and past_key_value is not None: - # reuse k, v, cross_attentions - key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) - value_states = past_key_value[1].transpose(1, 
2).contiguous().view(*attention_input_shape) + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self.k_proj(key_value_states).view(*attention_input_shape) - value_states = self.v_proj(key_value_states).view(*attention_input_shape) + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self.k_proj(hidden_states).view(*attention_input_shape) - value_states = self.v_proj(hidden_states).view(*attention_input_shape) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -565,38 +581,181 @@ def get_opt_flash_attention_forward(): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - src_len = key_states.size(1) - if layer_head_mask != None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) + query_states = self._shape(query_states, tgt_len, bsz) - flash_attention_mask = None - attn_mask_type = AttnMaskType.causal - if attention_mask != None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - if not torch.all(flash_attention_mask): - attn_mask_type = AttnMaskType.paddedcausal + dropout_p = self.dropout if self.training else 0.0 + attn_output = ColoAttention.attention( + query_states, + key_states, + value_states, + **attention_mask, + dropout_p=dropout_p, + scale=self.scaling, + ) - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling - ) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type - ) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) + return attn_output, None, past_key_value return forward +def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig): + from transformers.models.opt.modeling_opt import OPTDecoder + + def forward( + self: OPTDecoder, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length + + # embed positions + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _get_attention_mask( + self, shard_config, inputs_embeds, past_key_values_length, attention_mask + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + causal_attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward + + def get_jit_fused_opt_decoder_layer_forward(): from transformers.models.opt.modeling_opt import OPTDecoderLayer diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index ab141a74a..e9c256a13 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,4 +1,3 @@ -import math from typing import List, Optional, Tuple, Union import torch @@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention def _encoder_forward( @@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: pixel_values = pixel_values.to(expected_dtype) embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + pixel_values, + bool_masked_pos=bool_masked_pos, + interpolate_pos_encoding=interpolate_pos_encoding, ) hidden_states = embedding_output else: @@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.nn.layer.colo_attention import ColoAttention - - def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) - x = 
x.view(new_x_shape) - return x - def forward( self: ViTSelfAttention, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + assert head_mask is None, "head_mask is not supported for FlashAttention" mixed_query_layer = self.query(hidden_states) - key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) - value_layer = transpose_for_scores( - self.value(hidden_states), self.num_attention_heads, self.attention_head_size - ) - query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) - scale = 1.0 / math.sqrt(self.attention_head_size) - attention = ColoAttention( - embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale - ) - context_layer = attention(query_layer, key_layer, value_layer) + dropout_p = self.dropout.p if self.training else 0.0 + context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p) - outputs = (context_layer,) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, None) if output_attentions else (context_layer,) return outputs diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index cb8b45ae7..7ccc79276 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -13,41 +13,74 @@ from transformers.modeling_outputs import ( SequenceClassifierOutput, ) from transformers.models.whisper.modeling_whisper import ( + WhisperDecoder, WhisperEncoder, WhisperForAudioClassification, WhisperForConditionalGeneration, WhisperModel, + shift_tokens_right, ) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import ColoAttention +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def _get_attention_mask( + self: WhisperDecoder, + shard_config: ShardConfig, + hidden_states: torch.Tensor, + past_key_values_length: int, + attention_mask: Optional[torch.FloatTensor], +): + batch_size, seq_length = hidden_states.shape[:2] + mask_seq_length = past_key_values_length + seq_length + if shard_config.enable_flash_attention: + attention_mask = ColoAttention.prepare_attn_kwargs( + (batch_size, 1, seq_length, mask_seq_length), + hidden_states.dtype, + hidden_states.device, + attention_mask, + is_causal=True, + ) + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + return attention_mask def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - - def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): - return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() - def forward( self: WhisperAttention, hidden_states: torch.Tensor, 
key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[dict] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - + assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention" + # for encoder, attention_mask is None + if attention_mask is None: + attention_mask = {} # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + # get query proj + query_states = self.q_proj(hidden_states) # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward(): if ( is_cross_attention and past_key_value is not None - and past_key_value[0].shape[1] == key_value_states.shape[1] + and past_key_value[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention - key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - key_states = torch.cat([past_key_value[0], key_states], dim=1) - value_states = torch.cat([past_key_value[1], value_states], dim=1) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) - value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
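A shape-level sketch of why the key/value cache concatenation moves to dim=2 in these flash-attention rewrites: the _shape-style projections keep tensors in [batch, num_heads, seq_len, head_dim] layout, so the sequence axis is now dim 2 (the sizes below are made up for illustration):

import torch

bsz, num_heads, head_dim = 2, 8, 64
past_key = torch.randn(bsz, num_heads, 10, head_dim)  # cached keys from earlier decoding steps
new_key = torch.randn(bsz, num_heads, 1, head_dim)    # key projected for the newly decoded token
key_states = torch.cat([past_key, new_key], dim=2)
print(key_states.shape)  # torch.Size([2, 8, 11, 64])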
@@ -85,38 +118,22 @@ def get_whisper_flash_attention_forward(): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - # get query proj - query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz) - src_len = key_states.size(1) - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - - attn_type = None - flash_attention_mask = None - - if self.is_decoder: - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) - if not torch.all(flash_attention_mask): - attn_type = AttnMaskType.paddedcausal - else: - attn_type = AttnMaskType.causal - - attention = ColoAttention( - embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling - ) - attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type + dropout_p = self.dropout if self.training else 0.0 + attn_output = ColoAttention.attention( + query_states, + key_states, + value_states, + **attention_mask, + dropout_p=dropout_p, + scale=self.scaling, ) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) @@ -125,6 +142,158 @@ def get_whisper_flash_attention_forward(): return forward +def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig): + def forward( + self: WhisperDecoder, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + return forward + + def get_jit_fused_whisper_encoder_layer_forward(): from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer @@ -292,6 +461,7 @@ class WhisperPipelineForwards: all_attentions=None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Args: @@ -403,7 +573,9 @@ class WhisperPipelineForwards: if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) else: @@ -411,7 +583,7 @@ class WhisperPipelineForwards: @staticmethod def whisper_decoder_forward( - self, + self: WhisperDecoder, input_ids=None, attention_mask=None, encoder_hidden_states=None, @@ -427,6 +599,7 @@ class WhisperPipelineForwards: hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Args: @@ -535,8 +708,12 @@ class WhisperPipelineForwards: else: 
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + attention_mask = _get_attention_mask( + self, + shard_config, + inputs_embeds, + past_key_values_length, + attention_mask, ) hidden_states = inputs_embeds + positions @@ -556,8 +733,12 @@ class WhisperPipelineForwards: ) input_shape = hidden_states.size()[:-1] - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, hidden_states, past_key_values_length + attention_mask = _get_attention_mask( + self, + shard_config, + hidden_states, + past_key_values_length, + attention_mask, ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -590,7 +771,7 @@ class WhisperPipelineForwards: encoder_hidden_states, None, # encoder attention mask head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), None, # past_key_value ) else: @@ -626,7 +807,13 @@ class WhisperPipelineForwards: if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( @@ -666,6 +853,7 @@ class WhisperPipelineForwards: encoder_hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" Returns: @@ -735,7 +923,7 @@ class WhisperPipelineForwards: elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None), attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) @@ -767,6 +955,7 @@ class WhisperPipelineForwards: hidden_states=hidden_states, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=shard_config, ) # Directly return outputs of overloaded Whisper forward if not at last stage. @@ -810,6 +999,7 @@ class WhisperPipelineForwards: encoder_hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -870,6 +1060,7 @@ class WhisperPipelineForwards: encoder_hidden_states=encoder_hidden_states, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=shard_config, ) if not in_decoder: return outputs @@ -920,6 +1111,7 @@ class WhisperPipelineForwards: all_attentions=None, stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, + shard_config: Optional[ShardConfig] = None, ): r""" This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. 
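
# --- Illustrative sketch, not part of the patch ---
# The whisper hunks above add get_whisper_decoder_forward_for_flash_attention,
# a factory that closes over shard_config and is later registered as the
# decoder's forward by the policy. A minimal sketch of that closure pattern;
# ToyConfig and ToyDecoder below are hypothetical stand-ins, not repo classes.
from dataclasses import dataclass
from types import MethodType


@dataclass
class ToyConfig:
    enable_flash_attention: bool = True


def get_decoder_forward(shard_config: ToyConfig):
    def forward(self, hidden_states):
        # shard_config is captured by the closure, so the replacement forward
        # can branch on sharding options without changing its signature.
        if shard_config.enable_flash_attention:
            return hidden_states * 2  # placeholder for the flash-attention path
        return hidden_states

    return forward


class ToyDecoder:
    def forward(self, hidden_states):
        return hidden_states


decoder = ToyDecoder()
decoder.forward = MethodType(get_decoder_forward(ToyConfig()), decoder)
print(decoder.forward(3))  # -> 6
# --- end sketch ---
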
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6a50d65ba..fcf40fa39 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.gpt2 import ( GPT2PipelineForwards, get_gpt2_flash_attention_forward, + get_gpt_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, gpt2_sequence_parallel_forward_fn, ) @@ -75,7 +76,11 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "n_fused": 3, + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.c_proj", @@ -87,7 +92,11 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "n_fused": 1, + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", @@ -150,6 +159,10 @@ class GPT2Policy(Policy): policy=policy, target_key=GPT2Attention, ) + if not self.shard_config.pipeline_stage_manager: + policy[GPT2Model].method_replacement = { + "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) + } if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} @@ -223,14 +236,21 @@ class GPT2Policy(Policy): num_stages=stage_manager.num_stages, ) method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config, + ) } else: layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy + model_cls=GPT2Model, + new_forward=GPT2PipelineForwards.gpt2_model_forward, + policy=policy, ) return policy @@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] @@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): GPT2DoubleHeadsModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + 
suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ) ] ) @@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] @@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): addon_module = { GPT2ForTokenClassification: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) ] ) } diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index 9feb826c4..b001a2009 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -6,7 +6,11 @@ from torch import Tensor, nn import colossalai.shardformer.layer as col_nn -from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward +from ..modeling.gptj import ( + GPTJPipelineForwards, + get_gptj_flash_attention_forward, + gptj_model_forward_for_flash_attention, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -71,17 +75,26 @@ class GPTJPolicy(Policy): SubModuleReplacementDescription( suffix="attn.k_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.q_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.v_proj", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attn.out_proj", @@ -143,6 +156,12 @@ class GPTJPolicy(Policy): policy=policy, target_key=GPTJAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)}, + policy=policy, + target_key=GPTJModel, + ) return policy @@ -185,7 +204,10 @@ class GPTJPolicy(Policy): stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy + model_cls=GPTJModel, + new_forward=GPTJPipelineForwards.gptj_model_forward, + policy=policy, ) return policy @@ -230,7 
+254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): GPTJForCausalLM: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ) ] ) @@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy + model_cls=GPTJForCausalLM, + new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, + policy=policy, ) return policy @@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy): if stage_manager is not None: if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] + return [ + { + first_stage: module.transformer.wte.weight, + last_stage: module.lm_head.weight, + } + ] return [] diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 4c454ac7f..37c2c261b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro from ..modeling.llama import ( LlamaPipelineForwards, get_llama_flash_attention_forward, + get_llama_model_forward_for_flash_attn, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -135,6 +136,15 @@ class LlamaPolicy(Policy): policy=policy, target_key=LlamaAttention, ) + if self.pipeline_stage_manager is None: + # replace llama model forward method + self.append_or_create_method_replacement( + description={ + "forward": get_llama_model_forward_for_flash_attn(self.shard_config), + }, + policy=policy, + target_key=LlamaModel, + ) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index a542808ba..9a74da0b8 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -9,7 +9,12 @@ from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col from .._utils import getattr_ from ..modeling.jit import get_jit_fused_dropout_add_func -from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward +from ..modeling.opt import ( + OPTPipelineForwards, + get_jit_fused_opt_decoder_layer_forward, + get_opt_decoder_forward_for_flash_attention, + get_opt_flash_attention_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -27,6 +32,7 @@ class OPTPolicy(Policy): import transformers from packaging.version import Version + # TODO: remove this version check when transformers>=4.36.0 assert Version(transformers.__version__) <= Version( "4.33.0" ), "The OPT model should run on a transformers version not greater than 4.33.0." 
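
# --- Illustrative sketch, not part of the patch ---
# Further down, this OPT policy diff extends the pipeline method replacement so
# that functools.partial pre-binds stage_manager, stage_index and shard_config
# onto the shared pipeline forward. A minimal sketch of that pattern; all names
# here are hypothetical stand-ins.
from functools import partial


def pipeline_forward(self, hidden_states, stage_manager=None, stage_index=None, shard_config=None):
    start, end = stage_index
    # A real implementation would only execute the decoder layers in [start, end).
    return f"stage={stage_manager}, layers {start}..{end}, shard_config={shard_config}"


class ToyModel:
    pass


bound_forward = partial(pipeline_forward, stage_manager=0, stage_index=(0, 4), shard_config="cfg")
model = ToyModel()
print(bound_forward(model, hidden_states=None))
# --- end sketch ---
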
@@ -111,7 +117,9 @@ class OPTPolicy(Policy): # optimization configuration self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="final_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), policy=policy, target_key=OPTDecoder, @@ -119,10 +127,14 @@ class OPTPolicy(Policy): self.append_or_create_submodule_replacement( description=[ SubModuleReplacementDescription( - suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="self_attn_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True + suffix="final_layer_norm", + target_module=norm_cls, + ignore_if_not_exist=True, ), ], policy=policy, @@ -133,11 +145,19 @@ class OPTPolicy(Policy): if self.shard_config.enable_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_opt_flash_attention_forward(), + "forward": get_opt_flash_attention_forward(self.shard_config), }, policy=policy, target_key=OPTAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={ + "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config), + }, + policy=policy, + target_key=OPTDecoder, + ) # use jit fused operator if self.shard_config.enable_jit_fused: @@ -190,7 +210,14 @@ class OPTPolicy(Policy): layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = { + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, + ) + } self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls ) @@ -203,7 +230,9 @@ class OPTModelPolicy(OPTPolicy): policy = super().module_policy() if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy + model_cls=OPTModel, + new_forward=OPTPipelineForwards.opt_model_forward, + policy=policy, ) return policy @@ -223,14 +252,18 @@ class OPTForCausalLMPolicy(OPTPolicy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), ), policy=policy, target_key=OPTForCausalLM, ) if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy + model_cls=OPTForCausalLM, + new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, + policy=policy, ) return policy @@ -246,7 +279,12 @@ class OPTForCausalLMPolicy(OPTPolicy): if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: num_stages = self.pipeline_stage_manager.num_stages if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): - return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] + return [ + { + 0: 
opt_model.model.decoder.embed_tokens.weight, + num_stages - 1: opt_model.lm_head.weight, + } + ] return [] def postprocess(self): diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index b5b5db79d..14e1e3e0f 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -13,6 +13,7 @@ from ..modeling.whisper import ( WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, + get_whisper_decoder_forward_for_flash_attention, get_whisper_flash_attention_forward, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -31,6 +32,7 @@ class WhisperPolicy(Policy): import transformers from packaging.version import Version + # TODO: remove this version check when transformers>=4.36.0 assert Version(transformers.__version__) <= Version( "4.33.0" ), "The Whisper model should run on a transformers version not greater than 4.33.0." @@ -240,6 +242,14 @@ class WhisperPolicy(Policy): policy=policy, target_key=WhisperAttention, ) + if not self.shard_config.pipeline_stage_manager: + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_decoder_forward_for_flash_attention(self.shard_config), + }, + policy=policy, + target_key=WhisperDecoder, + ) # use jit fused operator if self.shard_config.enable_jit_fused: @@ -269,7 +279,9 @@ class WhisperPolicy(Policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="proj_out", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ), policy=base_policy, target_key=WhisperForConditionalGeneration, @@ -326,7 +338,10 @@ class WhisperPolicy(Policy): if stage < decoder_starting_stage: return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return Policy.get_stage_index( + layers_per_stage[decoder_starting_stage:], + stage - decoder_starting_stage, + ) def get_held_layers(self) -> List[nn.Module]: assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" @@ -422,6 +437,7 @@ class WhisperPolicy(Policy): stage_manager=stage_manager, stage_index=stage_index, decoder_starting_stage=decoder_starting_stage, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -436,7 +452,9 @@ class WhisperModelPolicy(WhisperPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy + model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy, ) return policy diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 4f2a4878e..e415b5fc3 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -40,7 +40,12 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" -def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False): 
+def check_state_dict_equal( + d1: OrderedDict, + d2: OrderedDict, + ignore_device: bool = True, + ignore_dtype: bool = False, +): assert len(list(d1.keys())) == len( list(d2.keys()) ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" @@ -94,7 +99,12 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic def assert_hf_output_close( - out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5 + out1: Any, + out2: Any, + ignore_keys: List[str] = None, + track_name: str = "", + atol=1e-5, + rtol=1e-5, ): """ Check if two outputs from huggingface are equal. @@ -113,7 +123,12 @@ def assert_hf_output_close( if ignore_keys is not None and k in ignore_keys: continue assert_hf_output_close( - out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + out1[k], + out2[k], + track_name=f"{track_name}.{k}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol, ) elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): # if two values are list @@ -121,12 +136,17 @@ def assert_hf_output_close( assert len(out1) == len(out2) for i in range(len(out1)): assert_hf_output_close( - out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + out1[i], + out2[i], + track_name=f"{track_name}.{i}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol, ) elif isinstance(out1, Tensor) and isinstance(out2, Tensor): if out1.shape != out2.shape: raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") - assert torch.allclose( + assert_close( out1, out2, atol=atol, rtol=rtol ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" else: diff --git a/extensions/README.md b/extensions/README.md index 6f5feb55c..b9bde7742 100644 --- a/extensions/README.md +++ b/extensions/README.md @@ -101,13 +101,13 @@ class MyExtension(_Extension): self._support_jit = True self.priority = 10 - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Return if the required hardware can be found. """ ... - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. 
""" diff --git a/extensions/__init__.py b/extensions/__init__.py index 9343cadda..0dbadba81 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -1,9 +1,5 @@ from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension -from .flash_attention import ( - FlashAttentionDaoCudaExtension, - FlashAttentionNpuExtension, - FlashAttentionXformersCudaExtension, -) +from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension from .layernorm import LayerNormCudaExtension from .moe import MoeCudaExtension from .optimizer import FusedOptimizerCudaExtension @@ -18,7 +14,7 @@ ALL_EXTENSIONS = [ ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension, FlashAttentionDaoCudaExtension, - FlashAttentionXformersCudaExtension, + FlashAttentionSdpaCudaExtension, FlashAttentionNpuExtension, ] @@ -31,6 +27,6 @@ __all__ = [ "ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension", "FlashAttentionDaoCudaExtension", - "FlashAttentionXformersCudaExtension", + "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension", ] diff --git a/extensions/base_extension.py b/extensions/base_extension.py index c815a7f2a..0c79c0a9e 100644 --- a/extensions/base_extension.py +++ b/extensions/base_extension.py @@ -58,13 +58,13 @@ class _Extension(ABC): return cache_directory @abstractmethod - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: """ Check if the hardware required by the kernel is available. """ @abstractmethod - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: """ Check if the hardware required by the kernel is compatible. """ diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py index 35bff3b55..61c4f3ed0 100644 --- a/extensions/cpu_adam/cpu_adam_arm.py +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension): def __init__(self): super().__init__(name="cpu_adam_arm") - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # only arm allowed return platform.machine() == "aarch64" - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "aarch64" diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/cpu_adam/cpu_adam_x86.py index a38194167..9bbc8d851 100644 --- a/extensions/cpu_adam/cpu_adam_x86.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension): def __init__(self): super().__init__(name="cpu_adam_x86") - def is_hardware_available(self) -> bool: - return platform.machine() == "x86_64" and super().is_hardware_available() + def is_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_available() - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: arch = platform.machine() assert ( arch == "x86_64" ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" - super().assert_hardware_compatible() + super().assert_compatible() # necessary 4 functions def sources_files(self): diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py index 842cd9713..f1e0095b2 100644 --- a/extensions/cuda_extension.py +++ b/extensions/cuda_extension.py @@ -22,7 +22,7 @@ class _CudaExtension(_CppExtension): This function should return a list of nvcc compilation flags for extensions. 
""" - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch @@ -32,7 +32,7 @@ class _CudaExtension(_CppExtension): cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> None: + def assert_compatible(self) -> None: from torch.utils.cpp_extension import CUDA_HOME if not CUDA_HOME: diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py index 18abb6191..ea5b442aa 100644 --- a/extensions/flash_attention/__init__.py +++ b/extensions/flash_attention/__init__.py @@ -1,20 +1,14 @@ from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension from .flash_attention_npu import FlashAttentionNpuExtension -from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension +from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension try: + # TODO: remove this after updating openmoe example import flash_attention # noqa HAS_FLASH_ATTN = True except: HAS_FLASH_ATTN = False -try: - import xformers # noqa - HAS_MEM_EFF_ATTN = True -except: - HAS_MEM_EFF_ATTN = False - - -__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py index 1b7f8ac47..a2f2a52f1 100644 --- a/extensions/flash_attention/flash_attention_dao_cuda.py +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: # cuda extension can only be built if cuda is available try: import torch + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func # noqa + from flash_attn.bert_padding import index_first_axis, pad_input # noqa + cuda_available = torch.cuda.is_available() except: cuda_available = False return cuda_available - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -29,65 +32,65 @@ class FlashAttentionDaoCudaExtension(_Extension): ) def load(self): - try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" - ) - ) - from typing import Optional import torch + from einops import rearrange + from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.bert_padding import index_first_axis, pad_input + + def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor): + return index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices) def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - # check if the input is in allowed dtypes - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( + # [B, N, S, D] -> [B, S, N, D] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + b, s_q = q.shape[:2] + if cu_seqlens_q is not None: + # padded / padded causal + # unpad input: [B, S, N, D] -> [T, N, D] + q = _unpad_input(q, q_indices) + kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + ) + # pad output: [T, N, D] -> [B, S, N, D] + attn_output = pad_input(attn_output, q_indices, b, s_q) + else: + # causal / no attn mask + attn_output = flash_attn_func( q, k, v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out + # [B, S, N, D] -> [B, N, S, D] + return attn_output.transpose(1, 2) return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py index 58d0f9306..0e01cefa1 100644 --- a/extensions/flash_attention/flash_attention_npu.py +++ b/extensions/flash_attention/flash_attention_npu.py @@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension): def __init__(self): super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) - def is_hardware_available(self) -> bool: + def is_available(self) -> bool: try: - import torch_npu # noqa + import torch_npu - return True + return hasattr(torch_npu, "npu_fusion_attention") except: return False - def assert_hardware_compatible(self) -> bool: + def assert_compatible(self) -> bool: pass def build_aot(self) -> None: @@ -27,47 +27,36 @@ class FlashAttentionNpuExtension(_Extension): ) def load(self): - import torch - from einops import rearrange + from typing import Optional - def npu_sdpa_attention( + import torch + import torch_npu + + def flash_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - 
seq_len_info_q=None, - seq_len_info_kv=None, - origin_attn_mask: torch.Tensor = None, dropout_p: float = 0.0, - scale: float = 1.0, - causal=None, - padded=None, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, ): - """ - The scaled dot product attention. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] - output = torch.nn.functional.scaled_dot_product_attention( + num_heads = q.size(1) + return torch_npu.npu_fusion_attention( q, k, v, - attn_mask=origin_attn_mask, - dropout_p=dropout_p, - is_causal=origin_attn_mask is None, + num_heads, + "BNSD", + atten_mask=attention_mask.bool(), scale=scale, - ) - output = rearrange(output, "b h s d -> b s (h d)") - return output + keep_prob=1 - dropout_p, + )[0] - return npu_sdpa_attention + return flash_attention diff --git a/extensions/flash_attention/flash_attention_sdpa_cuda.py b/extensions/flash_attention/flash_attention_sdpa_cuda.py new file mode 100644 index 000000000..d3323a6aa --- /dev/null +++ b/extensions/flash_attention/flash_attention_sdpa_cuda.py @@ -0,0 +1,56 @@ +from ..base_extension import _Extension + + +class FlashAttentionSdpaCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False) + + def is_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.") + + def build_jit(self) -> None: + raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.") + + def load(self): + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + q_indices: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + ): + return torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_mask, + dropout_p=dropout_p, + scale=scale, + ) + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py deleted file mode 100644 index 27cd823de..000000000 --- a/extensions/flash_attention/flash_attention_xformers_cuda.py +++ /dev/null @@ -1,94 +0,0 @@ -from ..base_extension import _Extension - - -class 
FlashAttentionXformersCudaExtension(_Extension): - def __init__(self): - super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) - - def is_hardware_available(self) -> bool: - # cuda extension can only be built if cuda is available - try: - import torch - - cuda_available = torch.cuda.is_available() - except: - cuda_available = False - return cuda_available - - def assert_hardware_compatible(self) -> bool: - pass - - def build_aot(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def build_jit(self) -> None: - raise NotImplementedError( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - - def load(self): - try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - except ImportError: - raise ModuleNotFoundError( - ( - "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." - ) - ) - from typing import Optional - - import torch - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: "SeqLenInfo", - seq_len_info_kv: "SeqLenInfo", - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." 
- attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out - - return mem_eff_attention diff --git a/setup.py b/setup.py index ef89481e6..c16709ad1 100644 --- a/setup.py +++ b/setup.py @@ -80,8 +80,8 @@ if BUILD_EXT: for ext_cls in ALL_EXTENSIONS: ext = ext_cls() - if ext.support_aot and ext.is_hardware_available(): - ext.assert_hardware_compatible() + if ext.support_aot and ext.is_available(): + ext.assert_compatible() op_names.append(ext.name) ext_modules.append(ext.build_aot()) diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py new file mode 100644 index 000000000..f9eab132f --- /dev/null +++ b/tests/test_shardformer/test_flash_attention.py @@ -0,0 +1,147 @@ +import math +from copy import copy + +import torch +from torch.testing import assert_close + +from colossalai.kernel.kernel_loader import ( + FlashAttentionLoader, + FlashAttentionWithCustomMaskLoader, + FlashAttentionWithPaddingMaskLoader, +) +from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer.attn import invert_mask +from colossalai.testing import clear_cache_before_run, parameterize +from colossalai.utils import get_current_device, set_seed + +DTYPE = [torch.float16, torch.bfloat16] +B, N, S, D = 2, 8, 256, 32 + +TOL_MAP = { + torch.float16: {"atol": 5e-4, "rtol": 2e-3}, + torch.bfloat16: {}, +} + + +def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0): + head_dim = q.size(-1) + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + if attn_mask is not None: + attn_weights = attn_weights + attn_mask + attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype) + attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True) + attn_output = torch.matmul(attn_weights, v) + return attn_output + + +def gen_padded_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, : S // 4] = 0 + return ( + ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask), + padding_mask, + ) + + +def gen_padded_causal_kwargs(dtype: torch.dtype): + padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device()) + padding_mask[0, S // 2 :] = 0 + return ( + ColoAttention.prepare_attn_kwargs( + (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True + ), + padding_mask, + ) + + +def gen_causal_kwargs(dtype: torch.dtype): + return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None + + +def gen_custom_kwargs(dtype: torch.dtype): + attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device()) + attn_mask[0, : S // 2, S // 2 :] = 0 + attn_mask[0, S // 2 :, : S // 2] = 0 + attn_mask[1, :, S // 4 :] = 0 + attn_mask = invert_mask(attn_mask).unsqueeze(1) + assert not torch.all(attn_mask != 0, dim=-1).any() + return {"attention_mask": attn_mask}, None + + +def post_process_kwargs_for_raw_attn(attn_kwargs: dict): + if "attention_mask_type" in attn_kwargs: + attn_kwargs = copy(attn_kwargs) + mask_type = attn_kwargs.pop("attention_mask_type") + attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, 
AttnMaskType.PADDED_CAUSAL) + return attn_kwargs + + +def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None): + tols = TOL_MAP[dtype] + q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True) + q_flash = q.clone().detach().requires_grad_(True) + k_flash = k.clone().detach().requires_grad_(True) + v_flash = v.clone().detach().requires_grad_(True) + attn_mask = attn_kwargs.get("attention_mask", None) + ref_output = attention_ref(q, k, v, attn_mask) + output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs) + if padding_mask is not None: + # [B, Sq] -> [B, 1, Sq, 1] + padding_mask = padding_mask[:, None, :, None].logical_not() + ref_output = ref_output.masked_fill(padding_mask, 0) + output = output.masked_fill(padding_mask, 0) + assert_close(output, ref_output, **tols) + output.mean().backward() + ref_output.mean().backward() + assert_close(q.grad, q_flash.grad, **tols) + assert_close(k.grad, k_flash.grad, **tols) + assert_close(v.grad, v_flash.grad, **tols) + + +@clear_cache_before_run() +@parameterize("dtype", DTYPE) +def test_flash_attn_func(dtype: torch.dtype): + torch.backends.cudnn.deterministic = True + set_seed(0) + # (func, name, need_postprocess) + avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)] + for ext_cls in FlashAttentionLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True)) + for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY: + ext = ext_cls() + if ext.is_available(): + ext.assert_compatible() + avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True)) + + test_sets = { + "none": (lambda dtype: ({}, None), avail_attn_funcs), + "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs), + "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs), + "causal": (gen_causal_kwargs, avail_attn_funcs), + "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs), + } + + for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items(): + attn_kwargs, padding_mask = gen_kwargs_func(dtype) + for attn_func, name, need_postprocess in attn_funcs: + print(f"{dtype}, {name}, {mask_type}") + if need_postprocess: + check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask) + else: + check_attn_func(dtype, attn_func, attn_kwargs, padding_mask) + + +if __name__ == "__main__": + test_flash_attn_func() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 62d4d1bf3..85be9a242 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -31,6 +31,7 @@ def build_model( enable_jit_fused=False, enable_sequence_parallelism=False, use_lazy_init: bool = False, + dtype=torch.float32, ): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() @@ -51,7 +52,7 @@ def build_model( model_copy = 
copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) - return org_model.cuda(), sharded_model.cuda() + return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype) def build_pipeline_model( @@ -132,7 +133,14 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c booster = Booster(plugin=plugin) sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) - return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster + return ( + org_model, + org_optimizer, + sharded_model, + sharded_optimizer, + criterion, + booster, + ) def run_forward_backward_with_hybrid_plugin( @@ -173,7 +181,12 @@ def run_forward_backward_with_hybrid_plugin( data_iter = iter([data]) sharded_output = booster.execute_pipeline( - data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + data_iter, + sharded_model, + _criterion, + sharded_optimizer, + return_loss=True, + return_outputs=True, ) sharded_loss = sharded_output["loss"] else: @@ -313,7 +326,9 @@ def check_grad( def unwrap_model( - module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None + module: Module, + base_model_class_name: Optional[str] = None, + base_model_attribute_name: Optional[str] = None, ): if isinstance(module, HybridParallelModule): module = module.unwrap() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index 02c15460e..2c56b0435 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -45,19 +45,51 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo "qformer.encoder.layer[0].attention.output.dense", "language_model.model.decoder.layers[0].self_attn.out_proj", ] - check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + check_grad( + blip2, + sharded_blip2, + col_layer_for_check, + atol=1e-6, + rtol=1e-5, + dim=0, + verbose=False, + ) + check_grad( + blip2, + sharded_blip2, + row_layer_for_check, + atol=1e-6, + rtol=1e-5, + dim=1, + verbose=False, + ) @parameterize("enable_fused_normalization", [True, False]) @parameterize("enable_tensor_parallelism", [True, False]) @parameterize("enable_flash_attention", [True, False]) @parameterize("enable_jit_fused", [True, False]) -def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): +def run_blip2_test( + enable_fused_normalization, + enable_tensor_parallelism, + enable_flash_attention, + enable_jit_fused, +): sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): org_model, sharded_model = build_model( - model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused + model_fn, + enable_fused_normalization, + enable_tensor_parallelism, + enable_flash_attention, + enable_jit_fused, + dtype=torch.float, ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, 
loss_fn) @@ -66,7 +98,14 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable def check_blip2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_blip2_test() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 29d3592bf..78d752b69 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -11,7 +11,6 @@ from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_all_grad_tensors, check_loss, - check_output_hidden_state, check_weight, get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, @@ -25,7 +24,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -36,7 +41,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") norm_layer_for_check = ["encoder.layers[0].input_layernorm"] - row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] + row_layer_for_check = [ + "encoder.layers[0].self_attention.query_key_value", + "embedding.word_embeddings", + ] col_layer_for_check = ["encoder.layers[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. 
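
# --- Illustrative sketch, not part of the patch ---
# The grad/weight checks above identify layers by string paths such as
# "encoder.layers[0].self_attention.dense". resolve_path below is a hypothetical
# helper that shows how such a path can be walked; it is not the repo's getattr_.
import re
from types import SimpleNamespace


def resolve_path(obj, path: str):
    for part in path.split("."):
        indexed = re.fullmatch(r"(\w+)\[(\d+)\]", part)
        if indexed:
            # handle list indexing like "layers[0]"
            obj = getattr(obj, indexed.group(1))[int(indexed.group(2))]
        else:
            obj = getattr(obj, part)
    return obj


layer = SimpleNamespace(self_attention=SimpleNamespace(dense="dense-layer"))
model = SimpleNamespace(encoder=SimpleNamespace(layers=[layer]))
print(resolve_path(model, "encoder.layers[0].self_attention.dense"))  # -> dense-layer
# --- end sketch ---
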
@@ -94,8 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == "ChatGLMModel": - check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong + # if org_model.__class__.__name__ == "ChatGLMModel": + # check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -143,8 +152,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, { "tp_size": 2, "pp_size": 1, @@ -159,7 +180,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_chatglm_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -193,7 +220,13 @@ def run_chatglm_test(test_config): def run_chatglm_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -202,13 +235,27 @@ def run_chatglm_3d_test(test_config): def check_chatglm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_chatglm_test() def check_chatglm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_chatglm_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3155420f1..d59d7e4ad 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + 
sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -47,10 +53,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( - gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) row_layer_grads = get_grad_tensors_for_check( - gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + gpt2, + sharded_gpt2, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) norm_layer_grads = get_grad_tensors_for_check( @@ -90,7 +110,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + check_weight( + gpt2, + sharded_gpt2, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -123,14 +152,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -138,7 +167,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": True, "precision": "fp32", }, @@ -167,7 +196,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_gpt2_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -202,7 +237,13 @@ def run_gpt2_test(test_config): def run_gpt2_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -211,13 +252,27 @@ def run_gpt2_3d_test(test_config): def check_gpt2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gpt2_test() def check_gpt2_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + 
config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gpt2_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py index c83eaaa09..009202a0d 100644 --- a/tests/test_shardformer/test_model/test_shard_gptj.py +++ b/tests/test_shardformer/test_model/test_shard_gptj.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -46,11 +52,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 col_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + gptj, + sharded_gptj, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) row_layer_grads = get_grad_tensors_for_check( - gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + gptj, + sharded_gptj, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,7 +97,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + check_weight( + gptj, + sharded_gptj, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -110,14 +139,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -125,7 +154,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 4, - "enable_all_optimization": True, + "enable_all_optimization": False, #'use_lazy_init': True, "precision": "fp32", }, @@ -154,7 +183,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_gptj_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -189,7 +224,13 @@ def run_gptj_test(test_config): def run_gptj_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in 
sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -198,15 +239,30 @@ def run_gptj_3d_test(test_config): def check_gptj(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gptj_test() def check_gptj_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_gptj_3d_test() + @pytest.mark.skip("TODO check_gptj has something wrong.") @pytest.mark.dist @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index c7edcfb35..126ff23a9 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -112,7 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index d21ab264d..523ed879b 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -29,7 +29,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -39,7 +45,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, opt_model = unwrap_model(org_model, "OPTModel", "model") shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model") - row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens' + row_layer_for_check = [ + "decoder.layers[0].self_attn.q_proj", + "decoder.embed_tokens", + ] # 'decoder.embed_tokens' col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"] # Save gradient tensors for comparison between the original model and the sharded model. 
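# A minimal standalone sketch of the sharding convention that the `dim`
# argument of get_grad_tensors_for_check / check_weight is assumed to follow in
# these tests: it names the dimension along which a parameter (and therefore
# its gradient) is split across the tensor-parallel group, which is why
# column- and row-split layers, and models with transposed weight layouts such
# as GPT2's Conv1D, pass different values. The shape, tp_size and rank below
# are illustrative values, not taken from the test suite.
import torch

tp_size, rank = 2, 0
full = torch.arange(24.0).reshape(4, 6)                # stands in for a full weight/gradient
shard_dim0 = torch.chunk(full, tp_size, dim=0)[rank]   # split along dim 0 -> shape (2, 6)
shard_dim1 = torch.chunk(full, tp_size, dim=1)[rank]   # split along dim 1 -> shape (4, 3)
print(shard_dim0.shape, shard_dim1.shape)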
@@ -50,10 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 4e-2, 4e-2 row_layer_grads = get_grad_tensors_for_check( - opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + opt_model, + shard_opt_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, ) col_layer_grads = get_grad_tensors_for_check( - opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -80,7 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 check_weight( - opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + opt_model, + shard_opt_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, ) # check grads @@ -110,8 +140,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, { "tp_size": 2, "pp_size": 1, @@ -135,7 +177,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_opt_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -169,7 +217,13 @@ def run_opt_test(test_config): def run_opt_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -178,13 +232,27 @@ def run_opt_3d_test(test_config): def check_OPTModel(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_opt_test() def check_opt_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_opt_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py 
b/tests/test_shardformer/test_model/test_shard_t5.py index 22c201458..9b22d54d7 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster, ) stage_manager = booster.plugin.stage_manager @@ -71,7 +77,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False) + check_weight( + t5, + sharded_t5, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -104,7 +119,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, @@ -117,7 +132,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", }, - {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, { "tp_size": 2, "pp_size": 1, @@ -144,7 +158,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, def run_t5_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): # skip 4-stage pp test for t5_encoder if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": continue @@ -185,7 +205,13 @@ def run_t5_test(test_config): def run_t5_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + for name, ( + model_fn, + data_gen_fn, + output_transform_fn, + loss_fn, + _, + ) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -194,13 +220,27 @@ def run_t5_3d_test(test_config): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_t5_test() def check_t5_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch( + config={}, + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) run_t5_3d_test() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py deleted file mode 100644 index 3ec170004..000000000 --- a/tests/test_utils/test_flash_attention.py +++ 
/dev/null @@ -1,167 +0,0 @@ -import math - -import pytest -import torch -from einops import rearrange - -from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN -from colossalai.testing import clear_cache_before_run, parameterize - -if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention - -DTYPE = [torch.float16, torch.bfloat16, torch.float32] - - -def attention_ref(q, k, v, attn_mask=None, causal=False): - """ - attention output of the control group - """ - dtype_og = q.dtype - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - d = q.shape[-1] - scale = 1.0 / math.sqrt(d) - scores = torch.einsum("bthd,bshd->bhts", q * scale, k) - - if attn_mask is not None: - scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - output = rearrange(output, "b s h d -> b s (h d)") - - # Modify the data at the positions of the mask to 0 - if attn_mask is not None: - output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) - - return output.to(dtype=dtype_og) - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_gpt(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)] - mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=True) - - # check gradients - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_bert(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - # attention mask of shape [B, S] with zero padding to max length S - mask = torch.randint(0, 2, (B, S), 
dtype=torch.bool, device="cuda") - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, mask, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_attention_no_mask(proj_shape, dtype, dropout): - (B, S, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v) - - assert list(y.shape) == [B, S, D] - - out_ref = attention_ref(q, k, v, None, causal=False) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" - - -@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") -@clear_cache_before_run() -@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) -@parameterize("dtype", DTYPE) -@parameterize("dropout", [0.0]) -def test_cross_attention(proj_shape, dtype, dropout): - (B, S, T, H, D_HEAD) = proj_shape - D = H * D_HEAD - - q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - - attn = ColoAttention(D, H, dropout=dropout) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) - - assert list(y.shape) == [B, T, D] - - out_ref = attention_ref(q, k, v, None, causal=True) - - dy = torch.rand_like(y) - grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy) - grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy) - - torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}" - torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}" - torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}" - torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}" From e6707a6e8d81d584597d7b8e5d8578158544d8f5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 
11:21:03 +0800 Subject: [PATCH 33/43] [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions --- colossalai/shardformer/modeling/gpt2.py | 1 - colossalai/shardformer/modeling/llama.py | 2 -- colossalai/shardformer/policies/gpt2.py | 8 ++++++-- colossalai/shardformer/policies/llama.py | 10 ++++++++-- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 72f923bf0..ea22cfb15 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group ) - if not shard_config.parallel_output: lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1f17144f5..29dc8200f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -15,10 +15,8 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig - from ..layer import ColoAttention, cross_entropy_1d - try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index fcf40fa39..5b43ecaed 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy): GPT2LMHeadModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output} + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": not self.shard_config.parallel_output}, ) ], ) } if self.shard_config.parallel_output: - addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + addon_module[GPT2LMHeadModel].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 37c2c261b..db8468713 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy): new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}) + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs={"gather_output": not self.shard_config.parallel_output}, + ) ], ) } if self.shard_config.parallel_output: - new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + new_item[LlamaForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } policy.update(new_item) if self.pipeline_stage_manager: From 
00525f77724428e7d883893d07bbfbf4dd1ad35e Mon Sep 17 00:00:00 2001 From: Insu Jang Date: Wed, 27 Mar 2024 01:57:00 -0400 Subject: [PATCH 34/43] [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen --- .../colossal_moe/models/mixtral_policy.py | 2 +- .../shardformer/policies/base_policy.py | 7 +- colossalai/shardformer/policies/bert.py | 67 ++++++++++++++----- colossalai/shardformer/policies/bloom.py | 4 +- colossalai/shardformer/policies/chatglm2.py | 4 +- colossalai/shardformer/policies/falcon.py | 4 +- colossalai/shardformer/policies/gpt2.py | 8 +-- colossalai/shardformer/policies/gptj.py | 4 +- colossalai/shardformer/policies/llama.py | 8 +-- colossalai/shardformer/policies/opt.py | 4 +- colossalai/shardformer/policies/t5.py | 32 +++++---- colossalai/shardformer/policies/vit.py | 4 +- colossalai/shardformer/policies/whisper.py | 30 ++++----- .../language/openmoe/model/openmoe_policy.py | 17 ++--- .../test_plugin/test_3d_plugin.py | 8 +-- .../test_t5_pipeline_utils.py | 18 +++-- .../test_whisper_pipeline_utils.py | 16 +++-- .../test_layer/test_dist_crossentropy.py | 5 +- 18 files changed, 136 insertions(+), 106 deletions(-) diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 218b05b27..23ffbf5d3 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -110,7 +110,7 @@ class MixtralPolicy(Policy): module = self.model.model layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 9a49b1ba6..762e75481 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -197,8 +197,7 @@ class Policy(ABC): """ return [] - @staticmethod - def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages""" quotient = num_layers // num_stages remainder = num_layers % num_stages @@ -213,8 +212,8 @@ class Policy(ABC): layers_per_stage[i] += 1 return layers_per_stage - @staticmethod def get_stage_index( + self, layers_per_stage: List[int], stage: int, num_model_chunks: int = 1, @@ -242,4 +241,4 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] stage_indices.append([start_idx, end_idx]) - return stage_indices[0] if num_model_chunks == 1 else stage_indices \ No newline at end of file + return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git 
a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0ab63b765..4d50a3c99 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -84,17 +84,26 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -112,7 +121,10 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="output.dense", @@ -214,7 +226,9 @@ class BertPolicy(Policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="decoder", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ), policy=base_policy, target_key=BertLMPredictionHead, @@ -241,7 +255,9 @@ class BertPolicy(Policy): "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, } self.append_or_create_method_replacement( - description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead + description=method_replacement, + policy=base_policy, + target_key=BertLMPredictionHead, ) return base_policy @@ -264,24 +280,32 @@ class BertPolicy(Policy): if stage_manager.is_interleave: layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks, ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, num_stages=stage_manager.num_stages, ) method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config, + ) } else: - layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + 
shard_config=self.shard_config, ) } @@ -301,9 +325,10 @@ class BertPolicy(Policy): if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks, ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -320,7 +345,7 @@ class BertPolicy(Policy): layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) @@ -336,7 +361,9 @@ class BertModelPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy + model_cls=BertModel, + new_forward=BertPipelineForwards.bert_model_forward, + policy=policy, ) return policy @@ -399,7 +426,9 @@ class BertLMHeadModelPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy + model_cls=BertLMHeadModel, + new_forward=BertPipelineForwards.bert_lm_head_model_forward, + policy=policy, ) return policy @@ -437,7 +466,9 @@ class BertForMaskedLMPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy + model_cls=BertForMaskedLM, + new_forward=BertPipelineForwards.bert_for_masked_lm_forward, + policy=policy, ) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index eddfafdcb..e4714c8c1 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -203,8 +203,8 @@ class BloomPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index d1ad9f914..cbe6254d1 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": 
partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 5c148880f..16bbc3f23 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -161,8 +161,8 @@ class FalconPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5b43ecaed..d1a8c9dce 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -188,7 +188,7 @@ class GPT2Policy(Policy): layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -229,7 +229,7 @@ class GPT2Policy(Policy): layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -243,8 +243,8 @@ class GPT2Policy(Policy): ) } else: - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b001a2009..b24443298 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -200,8 +200,8 @@ class GPTJPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index db8468713..daa7708c8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -167,7 +167,7 @@ class LlamaPolicy(Policy): layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -178,8 +178,8 @@ class 
LlamaPolicy(Policy): } else: - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -207,7 +207,7 @@ class LlamaPolicy(Policy): layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9a74da0b8..683f3a9d5 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -208,8 +208,8 @@ class OPTPolicy(Policy): else: module = self.model.model.decoder - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index e183b0632..f5f701dc0 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from functools import partial from typing import Callable, Dict, List, Tuple @@ -241,9 +243,8 @@ class T5BasePolicy(Policy): def postprocess(self): return self.model - @staticmethod def distribute_t5_layers( - num_encoder_layers: int, num_decoder_layers: int, num_stages: int + self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int ) -> Tuple[List[int], int]: """ Distribute t5 layers into stages when pipeline parallel is used. 
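# A small self-contained sketch of the encoder/decoder stage-splitting rule that
# distribute_t5_layers (and distribute_whisper_layers further below) applies:
# num_encoder_stages = argmin(|num_encoder_layers / encoder_stages -
# num_decoder_layers / decoder_stages|). The helper name and the layer counts
# used in the example are illustrative only.
import numpy as np

def split_encoder_decoder_stages(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    # choose the encoder/decoder stage split that best balances layers per stage
    def objective(i: int) -> float:
        return abs(num_encoder_layers / i - num_decoder_layers / (num_stages - i))

    return int(np.argmin([objective(i) for i in range(1, num_stages)])) + 1

# e.g. 12 encoder layers, 12 decoder layers and 4 pipeline stages give 2 encoder
# stages, so the decoder starts at stage index 2 and every stage holds 6 layers.
print(split_encoder_decoder_stages(12, 12, 4))  # 2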
@@ -261,7 +262,7 @@ class T5BasePolicy(Policy): # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + return self.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -272,22 +273,21 @@ class T5BasePolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages - @staticmethod def get_t5_stage_index( - layers_per_stage: List[int], stage: int, decoder_starting_stage: int + self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ if stage < decoder_starting_stage: - return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" @@ -302,12 +302,10 @@ class T5BasePolicy(Policy): num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - start_idx, end_idx = T5BasePolicy.get_t5_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) if stage_manager.stage < decoder_starting_stage: # current stage is in t5's encoder @@ -343,10 +341,10 @@ class T5BasePolicy(Policy): num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { "forward": partial( @@ -386,7 +384,7 @@ class T5ModelPolicy(T5BasePolicy): module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = 
T5BasePolicy.distribute_t5_layers( + _, decoder_starting_stage = self.distribute_t5_layers( len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages ) @@ -434,7 +432,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + _, decoder_starting_stage = self.distribute_t5_layers( len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages ) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 584d4e265..b0f224e22 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -149,8 +149,8 @@ class ViTPolicy(Policy): else: module = self.model.vit - layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 14e1e3e0f..480a4beea 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -292,9 +292,8 @@ class WhisperPolicy(Policy): def postprocess(self): return self.model - @staticmethod def distribute_whisper_layers( - num_encoder_layers: int, num_decoder_layers: int, num_stages: int + self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int ) -> Tuple[List[int], int]: """ Distribute whisper layers into stages when pipeline parallel is used. @@ -312,7 +311,7 @@ class WhisperPolicy(Policy): # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + return self.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -323,22 +322,21 @@ class WhisperPolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages - @staticmethod def get_whisper_stage_index( - layers_per_stage: List[int], stage: int, decoder_starting_stage: int + self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. 
Return the starting/ending idx of layers in encoder/decoder """ if stage < decoder_starting_stage: - return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index( + return self.get_stage_index( layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage, ) @@ -369,12 +367,10 @@ class WhisperPolicy(Policy): num_decoder_layers = 0 held_layers = [] - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) if stage_manager.stage < decoder_starting_stage: # current stage is in whisper's encoder @@ -424,12 +420,10 @@ class WhisperPolicy(Policy): else: num_decoder_layers = 0 - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - stage_index = WhisperPolicy.get_whisper_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + stage_index = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { "forward": partial( @@ -511,7 +505,7 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + _, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) shared_params = [] diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 17e7aa46c..66a42e017 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -98,11 +98,11 @@ class OpenMoePolicy(Policy): module = self.model.model layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -126,12 +126,9 @@ class OpenMoePolicy(Policy): held_layers.append(module.norm) return held_layers - - @staticmethod - def distribute_layers(num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages - """ + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages""" if num_layers == 24 and num_stages == 4: return [7, 7, 7, 3] elif num_layers == 24 and num_stages == 2: @@ -142,7 +139,7 @@ class OpenMoePolicy(Policy): return [8, 4] else: print(f"num_layers: {num_layers}, num_stages: {num_stages} 
not optimized, use origin pp policy") - return Policy.distribute_layers(num_layers, num_stages) + return super().distribute_layers(num_layers, num_stages) class OpenMoeModelPolicy(OpenMoePolicy): diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 61558c003..52cb8c46e 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -83,7 +83,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ @parameterize("init_method", ["none", "lazy"]) def check_3d_plugin(init_method: str = "none", early_stop: bool = True): - """check gemini plugin over model zoo + """check hybrid plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. @@ -260,7 +260,7 @@ def run_grad_acc_test(test_args): origin_model, origin_optimizer, dataloader=dataloader ) for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) def run_dist(rank, world_size, port, early_stop: bool = True): @@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() -def test_gemini_plugin(early_stop: bool = True): +def test_3d_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) if __name__ == "__main__": - test_gemini_plugin(early_stop=False) \ No newline at end of file + test_3d_plugin(early_stop=False) diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 3723c9c10..4ba67225f 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -10,9 +10,12 @@ def test_t5_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + policy = T5BasePolicy() for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + _, decoder_starting_stage = policy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage @@ -32,14 +35,15 @@ def test_t5_pipeline_layers(): } for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + policy = T5BasePolicy() + layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) for stage in range(test_dict["num_stages"][i]): start_idx, end_idx = test_dict["layers_per_stage"][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( - layers_per_stage, stage, decoder_starting_stage - ) + predicted_start, predicted_end = policy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage) assert start_idx == predicted_start assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 
f6be8f6fe..0500e46e8 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -10,9 +10,12 @@ def test_whisper_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + policy = WhisperPolicy() for i in range(num_test_cases): - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + _, decoder_starting_stage = policy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage @@ -31,14 +34,17 @@ def test_whisper_pipeline_layers(): ], } + policy = WhisperPolicy() for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) for stage in range(test_dict["num_stages"][i]): start_idx, end_idx = test_dict["layers_per_stage"][i][stage] - predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( + predicted_start, predicted_end = policy.get_whisper_stage_index( layers_per_stage, stage, decoder_starting_stage ) assert start_idx == predicted_start diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py index f594a80a4..414157c22 100644 --- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss, dist_loss, atol=1e-5 ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" - target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank] - assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}" + assert torch.allclose( + target_grad, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}" @pytest.mark.dist From 36c4bb2893e73022b1060bd6ad5c0685869e5465 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:30:04 +0800 Subject: [PATCH 35/43] [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code --- examples/language/grok-1/inference.py | 9 ++++----- examples/language/grok-1/inference_tp.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py index a73820af9..faef7ae9d 100644 --- a/examples/language/grok-1/inference.py +++ b/examples/language/grok-1/inference.py @@ -1,7 +1,7 @@ import time import torch -from transformers import AutoModelForCausalLM, LlamaTokenizerFast +from transformers import AutoModelForCausalLM, AutoTokenizer from utils import get_defualt_parser, inference, print_output if __name__ == "__main__": @@ -9,6 +9,9 @@ if __name__ == "__main__": args = parser.parse_args() start = time.time() torch.set_default_dtype(torch.bfloat16) + + 
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( args.pretrained, trust_remote_code=True, @@ -18,10 +21,6 @@ if __name__ == "__main__": model.eval() init_time = time.time() - start - # A transformers-compatible version of the grok-1 tokenizer by Xenova - # https://huggingface.co/Xenova/grok-1-tokenizer - tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") - for text in args.text: output = inference( model, diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py index 604de1487..cf05880dc 100644 --- a/examples/language/grok-1/inference_tp.py +++ b/examples/language/grok-1/inference_tp.py @@ -2,7 +2,7 @@ import time import torch from grok1_policy import Grok1ForCausalLMPolicy -from transformers import AutoModelForCausalLM, LlamaTokenizerFast +from transformers import AutoModelForCausalLM, AutoTokenizer from utils import get_defualt_parser, inference, print_output import colossalai @@ -27,6 +27,9 @@ if __name__ == "__main__": ) booster = Booster(plugin=plugin) torch.set_default_dtype(torch.bfloat16) + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) + with LazyInitContext(default_device=get_current_device()): model = AutoModelForCausalLM.from_pretrained( args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16 @@ -35,10 +38,6 @@ if __name__ == "__main__": model.eval() init_time = time.time() - start - # A transformers-compatible version of the grok-1 tokenizer by Xenova - # https://huggingface.co/Xenova/grok-1-tokenizer - tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") - for text in args.text: output = inference( model.unwrap(), From df5e9c53cf23d44656470cc319ee0b470c40712f Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Fri, 29 Mar 2024 14:12:29 +0800 Subject: [PATCH 36/43] [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. 
Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li --- .github/workflows/run_chatgpt_examples.yml | 29 +- .github/workflows/run_chatgpt_unit_tests.yml | 10 +- .gitignore | 4 + applications/Chat/benchmarks/README.md | 38 - .../benchmarks/benchmark_opt_lora_dummy.py | 208 --- applications/Chat/coati/dataset/__init__.py | 13 - .../Chat/coati/dataset/conversation.py | 89 - .../Chat/coati/dataset/prompt_dataset.py | 45 - .../Chat/coati/dataset/reward_dataset.py | 88 - .../Chat/coati/dataset/sft_dataset.py | 200 --- applications/Chat/coati/dataset/utils.py | 22 - .../Chat/coati/experience_maker/naive.py | 71 - applications/Chat/coati/kernels/__init__.py | 6 - applications/Chat/coati/kernels/opt_attn.py | 90 - applications/Chat/coati/kernels/wrapper.py | 18 - applications/Chat/coati/models/__init__.py | 15 - .../Chat/coati/models/base/__init__.py | 27 - applications/Chat/coati/models/base/actor.py | 33 - applications/Chat/coati/models/base/critic.py | 34 - .../Chat/coati/models/base/reward_model.py | 46 - .../Chat/coati/models/bloom/__init__.py | 5 - .../Chat/coati/models/bloom/bloom_actor.py | 36 - .../Chat/coati/models/bloom/bloom_critic.py | 36 - .../Chat/coati/models/bloom/bloom_rm.py | 36 - .../Chat/coati/models/chatglm/__init__.py | 3 - .../coati/models/chatglm/chatglm_actor.py | 31 - .../coati/models/chatglm/chatglm_tokenizer.py | 442 ----- .../models/chatglm/configuration_chatglm.py | 101 -- .../coati/models/chatglm/modeling_chatglm.py | 1477 ----------------- applications/Chat/coati/models/generation.py | 152 -- .../Chat/coati/models/gpt/__init__.py | 5 - .../Chat/coati/models/gpt/gpt_actor.py | 38 - .../Chat/coati/models/gpt/gpt_critic.py | 37 - applications/Chat/coati/models/gpt/gpt_rm.py | 37 - .../Chat/coati/models/llama/__init__.py | 5 - .../Chat/coati/models/llama/llama_actor.py | 38 - .../Chat/coati/models/llama/llama_critic.py | 36 - .../Chat/coati/models/llama/llama_rm.py | 37 - applications/Chat/coati/models/loss.py | 97 -- .../Chat/coati/models/opt/__init__.py | 5 - .../Chat/coati/models/opt/opt_actor.py | 37 - .../Chat/coati/models/opt/opt_critic.py | 37 - applications/Chat/coati/models/opt/opt_rm.py | 36 - applications/Chat/coati/models/utils.py | 69 - applications/Chat/coati/trainer/__init__.py | 6 - .../Chat/coati/trainer/callbacks/__init__.py | 5 - .../trainer/callbacks/save_checkpoint.py | 76 - applications/Chat/coati/trainer/ppo.py | 202 --- applications/Chat/coati/trainer/rm.py | 123 -- applications/Chat/coati/trainer/sft.py | 130 -- .../Chat/coati/trainer/strategies/__init__.py | 5 - .../Chat/coati/trainer/strategies/base.py | 137 -- .../coati/trainer/strategies/colossalai.py | 209 --- .../Chat/coati/trainer/strategies/ddp.py | 136 -- .../Chat/coati/trainer/strategies/sampler.py | 31 - applications/Chat/coati/trainer/utils.py | 50 - applications/Chat/examples/README.md | 409 ----- 
applications/Chat/examples/download_model.py | 79 - .../examples/generate_conversation_dataset.py | 82 - .../Chat/examples/generate_prompt_dataset.py | 27 - applications/Chat/examples/inference.py | 73 - applications/Chat/examples/train_prompts.py | 249 --- applications/Chat/examples/train_prompts.sh | 25 - .../Chat/examples/train_reward_model.py | 208 --- applications/Chat/examples/train_rm.sh | 25 - applications/Chat/examples/train_sft.py | 221 --- applications/Chat/examples/train_sft.sh | 28 - applications/Chat/inference/benchmark.py | 141 -- .../Chat/inference/tests/test_chat_prompt.py | 61 - applications/Chat/inference/utils.py | 209 --- applications/Chat/requirements-test.txt | 2 - applications/Chat/requirements.txt | 14 - applications/Chat/tests/test_benchmarks.sh | 33 - applications/Chat/tests/test_checkpoint.py | 91 - applications/Chat/tests/test_dataset.py | 241 --- applications/Chat/tests/test_experience.py | 130 -- applications/Chat/tests/test_inference.sh | 11 - applications/Chat/tests/test_models.py | 245 --- applications/Chat/tests/test_train.sh | 233 --- .../{Chat => ColossalChat}/.gitignore | 13 +- applications/{Chat => ColossalChat}/LICENSE | 0 applications/{Chat => ColossalChat}/README.md | 348 ++-- applications/ColossalChat/benchmarks/Opt.json | 17 + .../ColossalChat/benchmarks/README.md | 37 + .../benchmark_memory_consumption.txt | 4 + .../benchmark_performance_summarization.txt | 16 + .../ColossalChat/benchmarks/benchmark_ppo.py | 523 ++++++ .../ColossalChat/benchmarks/benchmark_ppo.sh | 119 ++ .../benchmarks/data_preparation.sh | 16 + .../benchmarks/ray/1mmt_dummy.py | 0 .../benchmarks/ray/mmmt_dummy.py | 0 .../{Chat => ColossalChat}/coati/__init__.py | 0 .../ColossalChat/coati/dataset/__init__.py | 26 + .../coati/dataset/conversation.py | 143 ++ .../ColossalChat/coati/dataset/loader.py | 383 +++++ .../coati/dataset/tokenization_utils.py | 383 +++++ .../ColossalChat/coati/dataset/utils.py | 138 ++ .../coati/experience_buffer/__init__.py | 0 .../coati/experience_buffer/base.py | 0 .../coati/experience_buffer/naive.py | 13 +- .../coati/experience_buffer/utils.py | 5 +- .../coati/experience_maker/__init__.py | 0 .../coati/experience_maker/base.py | 26 +- .../coati/experience_maker/naive.py | 180 ++ .../ColossalChat/coati/models/__init__.py | 24 + .../ColossalChat/coati/models/base.py | 58 + .../ColossalChat/coati/models/critic.py | 34 + .../ColossalChat/coati/models/generation.py | 428 +++++ .../coati/models/lora.py | 52 +- .../ColossalChat/coati/models/loss.py | 169 ++ .../ColossalChat/coati/models/reward_model.py | 38 + .../ColossalChat/coati/models/utils.py | 137 ++ .../coati/quant/__init__.py | 0 .../coati/quant/llama_gptq/__init__.py | 0 .../coati/quant/llama_gptq/loader.py | 0 .../coati/quant/llama_gptq/model_utils.py | 0 .../coati/quant/llama_gptq/quant.py | 0 .../coati/quant/utils.py | 0 .../coati/ray/README.md | 0 .../coati/ray/__init__.py | 0 .../coati/ray/callbacks/__init__.py | 0 .../coati/ray/callbacks/base.py | 0 .../ray/callbacks/performance_evaluator.py | 0 .../coati/ray/detached_replay_buffer.py | 0 .../coati/ray/detached_trainer_base.py | 0 .../coati/ray/detached_trainer_ppo.py | 0 .../coati/ray/experience_maker_holder.py | 0 .../coati/ray/lora_constructor.py | 0 .../{Chat => ColossalChat}/coati/ray/utils.py | 4 +- .../ColossalChat/coati/trainer/__init__.py | 7 + .../coati/trainer/base.py | 56 +- .../coati/trainer/callbacks/__init__.py | 4 + .../coati/trainer/callbacks/base.py | 0 .../callbacks/performance_evaluator.py | 16 +- 
.../ColossalChat/coati/trainer/dpo.py | 336 ++++ .../ColossalChat/coati/trainer/ppo.py | 403 +++++ applications/ColossalChat/coati/trainer/rm.py | 242 +++ .../ColossalChat/coati/trainer/sft.py | 170 ++ .../ColossalChat/coati/trainer/utils.py | 113 ++ .../ColossalChat/coati/utils/__init__.py | 4 + .../coati/utils/accumulative_meter.py | 69 + .../ColossalChat/coati/utils/ckpt_io.py | 93 ++ .../config/conversation_template/Qwen.json | 7 + .../config/conversation_template/Vicuna.json | 7 + .../config/conversation_template/Yi.json | 7 + .../conversation_template/chatGLM2.json | 7 + .../colossal-llama2.json | 7 + .../config/conversation_template/llama2.json | 7 + .../config/conversation_template/mistral.json | 7 + .../config/conversation_template/zephyr.json | 7 + applications/ColossalChat/examples/README.md | 565 +++++++ .../examples/community/README.md | 0 .../examples/community/peft/README.md | 0 .../examples/community/peft/easy_dataset.py | 0 .../examples/community/peft/easy_models.py | 0 .../community/peft/train_peft_prompts.py | 4 +- .../examples/community/peft/train_peft_sft.py | 0 .../examples/community/ray/README.md | 0 .../examples/community/ray/ray_job_script.py | 0 .../community/ray/train_prompts_on_ray.py | 0 .../prepare_dataset.py | 268 +++ .../prepare_preference_dataset.sh | 13 + .../prepare_prompt_dataset.sh | 13 + .../prepare_sft_dataset.sh | 13 + .../ColossalChat/examples/inference/chatio.py | 168 ++ .../examples/inference/inference.py | 195 +++ .../examples/inference/web_chatbot}/README.md | 0 .../inference/web_chatbot}/locustfile.py | 0 .../inference/web_chatbot}/requirements.txt | 0 .../examples/inference/web_chatbot}/server.py | 46 +- .../examples/inference/web_chatbot/utils.py | 78 + .../examples/ray/1mmt_prompt.py | 0 .../examples/ray/mmmt_prompt.py | 0 .../examples/ray/requirements.txt | 0 .../examples/ray/test_ci.sh | 0 .../examples/requirements.txt | 3 +- .../examples/training_scripts/hostfile | 1 + .../examples/training_scripts/train_dpo.py | 326 ++++ .../examples/training_scripts/train_dpo.sh | 62 + .../examples/training_scripts/train_ppo.py | 506 ++++++ .../examples/training_scripts/train_ppo.sh | 82 + .../examples/training_scripts/train_rm.py | 342 ++++ .../examples/training_scripts/train_rm.sh | 61 + .../examples/training_scripts/train_sft.py | 311 ++++ .../examples/training_scripts/train_sft.sh | 59 + .../{Chat => ColossalChat}/pytest.ini | 0 applications/ColossalChat/requirements.txt | 24 + applications/{Chat => ColossalChat}/setup.py | 2 +- .../{Chat => ColossalChat}/tests/__init__.py | 0 .../generate_dummy_datasets_for_testing.py | 72 + applications/ColossalChat/tests/llama.json | 8 + applications/ColossalChat/tests/opt.json | 17 + .../tests/test_data/dpo/test_dpo_data.jsonl | 1 + .../tests/test_data/sft/test_sft_data.jsonl | 1 + .../tests/test_data_preparation.sh | 260 +++ applications/ColossalChat/tests/test_lora.py | 69 + .../ColossalChat/tests/test_templating.sh | 97 ++ applications/ColossalChat/tests/test_train.sh | 397 +++++ .../ColossalChat/tests/verify_chat_data.py | 64 + .../{Chat => ColossalChat}/version.txt | 0 200 files changed, 8848 insertions(+), 8049 deletions(-) delete mode 100644 applications/Chat/benchmarks/README.md delete mode 100644 applications/Chat/benchmarks/benchmark_opt_lora_dummy.py delete mode 100644 applications/Chat/coati/dataset/__init__.py delete mode 100644 applications/Chat/coati/dataset/conversation.py delete mode 100644 applications/Chat/coati/dataset/prompt_dataset.py delete mode 100644 
applications/Chat/coati/dataset/reward_dataset.py delete mode 100644 applications/Chat/coati/dataset/sft_dataset.py delete mode 100644 applications/Chat/coati/dataset/utils.py delete mode 100644 applications/Chat/coati/experience_maker/naive.py delete mode 100644 applications/Chat/coati/kernels/__init__.py delete mode 100644 applications/Chat/coati/kernels/opt_attn.py delete mode 100644 applications/Chat/coati/kernels/wrapper.py delete mode 100644 applications/Chat/coati/models/__init__.py delete mode 100644 applications/Chat/coati/models/base/__init__.py delete mode 100644 applications/Chat/coati/models/base/actor.py delete mode 100644 applications/Chat/coati/models/base/critic.py delete mode 100644 applications/Chat/coati/models/base/reward_model.py delete mode 100644 applications/Chat/coati/models/bloom/__init__.py delete mode 100644 applications/Chat/coati/models/bloom/bloom_actor.py delete mode 100644 applications/Chat/coati/models/bloom/bloom_critic.py delete mode 100644 applications/Chat/coati/models/bloom/bloom_rm.py delete mode 100644 applications/Chat/coati/models/chatglm/__init__.py delete mode 100644 applications/Chat/coati/models/chatglm/chatglm_actor.py delete mode 100644 applications/Chat/coati/models/chatglm/chatglm_tokenizer.py delete mode 100644 applications/Chat/coati/models/chatglm/configuration_chatglm.py delete mode 100644 applications/Chat/coati/models/chatglm/modeling_chatglm.py delete mode 100644 applications/Chat/coati/models/generation.py delete mode 100644 applications/Chat/coati/models/gpt/__init__.py delete mode 100644 applications/Chat/coati/models/gpt/gpt_actor.py delete mode 100644 applications/Chat/coati/models/gpt/gpt_critic.py delete mode 100644 applications/Chat/coati/models/gpt/gpt_rm.py delete mode 100644 applications/Chat/coati/models/llama/__init__.py delete mode 100644 applications/Chat/coati/models/llama/llama_actor.py delete mode 100644 applications/Chat/coati/models/llama/llama_critic.py delete mode 100644 applications/Chat/coati/models/llama/llama_rm.py delete mode 100644 applications/Chat/coati/models/loss.py delete mode 100644 applications/Chat/coati/models/opt/__init__.py delete mode 100644 applications/Chat/coati/models/opt/opt_actor.py delete mode 100644 applications/Chat/coati/models/opt/opt_critic.py delete mode 100644 applications/Chat/coati/models/opt/opt_rm.py delete mode 100644 applications/Chat/coati/models/utils.py delete mode 100644 applications/Chat/coati/trainer/__init__.py delete mode 100644 applications/Chat/coati/trainer/callbacks/__init__.py delete mode 100644 applications/Chat/coati/trainer/callbacks/save_checkpoint.py delete mode 100644 applications/Chat/coati/trainer/ppo.py delete mode 100644 applications/Chat/coati/trainer/rm.py delete mode 100644 applications/Chat/coati/trainer/sft.py delete mode 100644 applications/Chat/coati/trainer/strategies/__init__.py delete mode 100644 applications/Chat/coati/trainer/strategies/base.py delete mode 100644 applications/Chat/coati/trainer/strategies/colossalai.py delete mode 100644 applications/Chat/coati/trainer/strategies/ddp.py delete mode 100644 applications/Chat/coati/trainer/strategies/sampler.py delete mode 100644 applications/Chat/coati/trainer/utils.py delete mode 100644 applications/Chat/examples/README.md delete mode 100644 applications/Chat/examples/download_model.py delete mode 100644 applications/Chat/examples/generate_conversation_dataset.py delete mode 100644 applications/Chat/examples/generate_prompt_dataset.py delete mode 100644 
applications/Chat/examples/inference.py delete mode 100644 applications/Chat/examples/train_prompts.py delete mode 100755 applications/Chat/examples/train_prompts.sh delete mode 100644 applications/Chat/examples/train_reward_model.py delete mode 100755 applications/Chat/examples/train_rm.sh delete mode 100644 applications/Chat/examples/train_sft.py delete mode 100755 applications/Chat/examples/train_sft.sh delete mode 100644 applications/Chat/inference/benchmark.py delete mode 100644 applications/Chat/inference/tests/test_chat_prompt.py delete mode 100644 applications/Chat/inference/utils.py delete mode 100644 applications/Chat/requirements-test.txt delete mode 100644 applications/Chat/requirements.txt delete mode 100755 applications/Chat/tests/test_benchmarks.sh delete mode 100644 applications/Chat/tests/test_checkpoint.py delete mode 100644 applications/Chat/tests/test_dataset.py delete mode 100644 applications/Chat/tests/test_experience.py delete mode 100755 applications/Chat/tests/test_inference.sh delete mode 100644 applications/Chat/tests/test_models.py delete mode 100755 applications/Chat/tests/test_train.sh rename applications/{Chat => ColossalChat}/.gitignore (87%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/LICENSE (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/README.md (59%) mode change 100644 => 100755 create mode 100644 applications/ColossalChat/benchmarks/Opt.json create mode 100755 applications/ColossalChat/benchmarks/README.md create mode 100644 applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt create mode 100644 applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt create mode 100644 applications/ColossalChat/benchmarks/benchmark_ppo.py create mode 100755 applications/ColossalChat/benchmarks/benchmark_ppo.sh create mode 100755 applications/ColossalChat/benchmarks/data_preparation.sh rename applications/{Chat => ColossalChat}/benchmarks/ray/1mmt_dummy.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/benchmarks/ray/mmmt_dummy.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/__init__.py (100%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/coati/dataset/__init__.py create mode 100755 applications/ColossalChat/coati/dataset/conversation.py create mode 100755 applications/ColossalChat/coati/dataset/loader.py create mode 100755 applications/ColossalChat/coati/dataset/tokenization_utils.py create mode 100755 applications/ColossalChat/coati/dataset/utils.py rename applications/{Chat => ColossalChat}/coati/experience_buffer/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/experience_buffer/base.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/experience_buffer/naive.py (86%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/experience_buffer/utils.py (94%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/experience_maker/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/experience_maker/base.py (74%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/coati/experience_maker/naive.py create mode 100755 applications/ColossalChat/coati/models/__init__.py create mode 100755 applications/ColossalChat/coati/models/base.py create mode 100755 
applications/ColossalChat/coati/models/critic.py create mode 100755 applications/ColossalChat/coati/models/generation.py rename applications/{Chat => ColossalChat}/coati/models/lora.py (85%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/coati/models/loss.py create mode 100755 applications/ColossalChat/coati/models/reward_model.py create mode 100755 applications/ColossalChat/coati/models/utils.py rename applications/{Chat => ColossalChat}/coati/quant/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/quant/llama_gptq/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/quant/llama_gptq/loader.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/quant/llama_gptq/model_utils.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/quant/llama_gptq/quant.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/quant/utils.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/README.md (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/callbacks/__init__.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/callbacks/base.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/callbacks/performance_evaluator.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/detached_replay_buffer.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/detached_trainer_base.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/detached_trainer_ppo.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/experience_maker_holder.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/lora_constructor.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/coati/ray/utils.py (97%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/coati/trainer/__init__.py rename applications/{Chat => ColossalChat}/coati/trainer/base.py (80%) mode change 100644 => 100755 create mode 100644 applications/ColossalChat/coati/trainer/callbacks/__init__.py rename applications/{Chat => ColossalChat}/coati/trainer/callbacks/base.py (100%) rename applications/{Chat => ColossalChat}/coati/trainer/callbacks/performance_evaluator.py (92%) create mode 100755 applications/ColossalChat/coati/trainer/dpo.py create mode 100755 applications/ColossalChat/coati/trainer/ppo.py create mode 100755 applications/ColossalChat/coati/trainer/rm.py create mode 100755 applications/ColossalChat/coati/trainer/sft.py create mode 100755 applications/ColossalChat/coati/trainer/utils.py create mode 100755 applications/ColossalChat/coati/utils/__init__.py create mode 100755 applications/ColossalChat/coati/utils/accumulative_meter.py create mode 100755 applications/ColossalChat/coati/utils/ckpt_io.py create mode 100644 applications/ColossalChat/config/conversation_template/Qwen.json create mode 100644 applications/ColossalChat/config/conversation_template/Vicuna.json create mode 100644 applications/ColossalChat/config/conversation_template/Yi.json create mode 100644 
applications/ColossalChat/config/conversation_template/chatGLM2.json create mode 100644 applications/ColossalChat/config/conversation_template/colossal-llama2.json create mode 100644 applications/ColossalChat/config/conversation_template/llama2.json create mode 100644 applications/ColossalChat/config/conversation_template/mistral.json create mode 100644 applications/ColossalChat/config/conversation_template/zephyr.json create mode 100755 applications/ColossalChat/examples/README.md rename applications/{Chat => ColossalChat}/examples/community/README.md (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/peft/README.md (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/peft/easy_dataset.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/peft/easy_models.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/peft/train_peft_prompts.py (98%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/peft/train_peft_sft.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/ray/README.md (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/ray/ray_job_script.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/community/ray/train_prompts_on_ray.py (100%) mode change 100644 => 100755 create mode 100644 applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py create mode 100755 applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh create mode 100755 applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh create mode 100755 applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh create mode 100755 applications/ColossalChat/examples/inference/chatio.py create mode 100755 applications/ColossalChat/examples/inference/inference.py rename applications/{Chat/inference => ColossalChat/examples/inference/web_chatbot}/README.md (100%) mode change 100644 => 100755 rename applications/{Chat/inference => ColossalChat/examples/inference/web_chatbot}/locustfile.py (100%) mode change 100644 => 100755 rename applications/{Chat/inference => ColossalChat/examples/inference/web_chatbot}/requirements.txt (100%) mode change 100644 => 100755 rename applications/{Chat/inference => ColossalChat/examples/inference/web_chatbot}/server.py (79%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/examples/inference/web_chatbot/utils.py rename applications/{Chat => ColossalChat}/examples/ray/1mmt_prompt.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/ray/mmmt_prompt.py (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/ray/requirements.txt (100%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/examples/ray/test_ci.sh (100%) rename applications/{Chat => ColossalChat}/examples/requirements.txt (51%) create mode 100755 applications/ColossalChat/examples/training_scripts/hostfile create mode 100755 applications/ColossalChat/examples/training_scripts/train_dpo.py create mode 100755 applications/ColossalChat/examples/training_scripts/train_dpo.sh create mode 100755 applications/ColossalChat/examples/training_scripts/train_ppo.py create mode 100755 
applications/ColossalChat/examples/training_scripts/train_ppo.sh create mode 100755 applications/ColossalChat/examples/training_scripts/train_rm.py create mode 100755 applications/ColossalChat/examples/training_scripts/train_rm.sh create mode 100755 applications/ColossalChat/examples/training_scripts/train_sft.py create mode 100755 applications/ColossalChat/examples/training_scripts/train_sft.sh rename applications/{Chat => ColossalChat}/pytest.ini (100%) mode change 100644 => 100755 create mode 100755 applications/ColossalChat/requirements.txt rename applications/{Chat => ColossalChat}/setup.py (97%) mode change 100644 => 100755 rename applications/{Chat => ColossalChat}/tests/__init__.py (100%) mode change 100644 => 100755 create mode 100644 applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py create mode 100644 applications/ColossalChat/tests/llama.json create mode 100644 applications/ColossalChat/tests/opt.json create mode 100644 applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl create mode 100644 applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl create mode 100755 applications/ColossalChat/tests/test_data_preparation.sh create mode 100755 applications/ColossalChat/tests/test_lora.py create mode 100755 applications/ColossalChat/tests/test_templating.sh create mode 100755 applications/ColossalChat/tests/test_train.sh create mode 100644 applications/ColossalChat/tests/verify_chat_data.py rename applications/{Chat => ColossalChat}/version.txt (100%) mode change 100644 => 100755 diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index bb0ceb4a8..ba997f144 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -19,8 +19,8 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb - timeout-minutes: 30 + options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb + timeout-minutes: 60 defaults: run: shell: bash @@ -28,26 +28,35 @@ jobs: - name: Checkout ColossalAI uses: actions/checkout@v2 + - name: Install Colossal-AI + run: | + BUILD_EXT=1 pip install -v -e . + - name: Install ChatGPT run: | - cd applications/Chat + cd applications/ColossalChat pip install -v . 
+ export BUILD_EXT=1 pip install -r examples/requirements.txt - name: Install Transformers run: | - pip install transformers==4.30.2 + pip install transformers==4.34.1 - name: Execute Examples run: | - cd applications/Chat + cd applications/ColossalChat rm -rf ~/.cache/colossalai - ./tests/test_inference.sh - ./tests/test_benchmarks.sh + mkdir models + mkdir sft_data + mkdir prompt_data + mkdir preference_data + ./tests/test_data_preparation.sh ./tests/test_train.sh env: NCCL_SHM_DISABLE: 1 MAX_JOBS: 8 - SFT_DATASET: /data/scratch/github_actions/chat/data.json - PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl - PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json + PRETRAINED_MODEL_PATH: ./models + SFT_DATASET: ./sft_data + PROMPT_DATASET: ./prompt_data + PREFERENCE_DATASET: ./preference_data diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index 7986889e0..1d8a53e4f 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -21,7 +21,7 @@ jobs: runs-on: [self-hosted, gpu] container: image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 - options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt + options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data timeout-minutes: 30 defaults: run: @@ -32,15 +32,17 @@ jobs: - name: Install ChatGPT run: | - cd applications/Chat + cd applications/ColossalChat pip install -v . - pip install -r requirements-test.txt + pip install pytest - name: Execute Unit Testing run: | - cd applications/Chat + cd applications/ColossalChat rm -rf ~/.cache/colossalai pytest tests/ + cd ./tests + ./test_templating.sh env: NCCL_SHM_DISABLE: 1 MAX_JOBS: 8 diff --git a/.gitignore b/.gitignore index 81113fa99..e539a2612 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,7 @@ coverage.xml # ignore testmon and coverage files .coverage .testmondata* + +# log, test files - ColossalChat +applications/ColossalChat/logs +applications/ColossalChat/tests/logs \ No newline at end of file diff --git a/applications/Chat/benchmarks/README.md b/applications/Chat/benchmarks/README.md deleted file mode 100644 index c13f34858..000000000 --- a/applications/Chat/benchmarks/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Benchmarks - -## Benchmark OPT with LoRA on dummy prompt data - -We provide various OPT models (string in parentheses is the corresponding model name used in this script): - -- OPT-125M (125m) -- OPT-350M (350m) -- OPT-700M (700m) -- OPT-1.3B (1.3b) -- OPT-2.7B (2.7b) -- OPT-3.5B (3.5b) -- OPT-5.5B (5.5b) -- OPT-6.7B (6.7b) -- OPT-10B (10b) -- OPT-13B (13b) - -We also provide various training strategies: - -- ddp: torch DDP -- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3 -- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload -- colossalai_zero2: ColossalAI zero2 -- colossalai_zero2_cpu: ColossalAI zero2-offload -- colossalai_zero1: ColossalAI zero1 -- colossalai_zero1_cpu: ColossalAI zero1-offload - -We only support `torchrun` to launch now. E.g. 
- -```bash -# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size -torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py \ - --model 125m --critic_model 125m --strategy ddp \ - --experience_batch_size 1 --train_batch_size 1 --lora_rank 0 -# run Actor (OPT-1.3B) and Critic (OPT-350M) with lora_rank=4 on single-node 4-GPU -torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py \ - --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4 -``` diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py deleted file mode 100644 index 0d0e2a7d3..000000000 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ /dev/null @@ -1,208 +0,0 @@ -import argparse -from copy import deepcopy - -import torch -import torch.distributed as dist -import torch.nn as nn -from coati.models.base import RewardModel -from coati.models.opt import OPTActor, OPTCritic -from coati.trainer import PPOTrainer -from coati.trainer.callbacks import PerformanceEvaluator -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy -from torch.optim import Adam -from torch.utils.data import DataLoader -from transformers import AutoTokenizer -from transformers.models.opt.configuration_opt import OPTConfig - -from colossalai.nn.optimizer import HybridAdam - - -def get_model_numel(model: nn.Module, strategy: Strategy) -> int: - numel = sum(p.numel() for p in model.parameters()) - if isinstance(strategy, GeminiStrategy) and strategy.shard_init: - numel *= dist.get_world_size() - return numel - - -def preprocess_batch(samples) -> dict: - input_ids = torch.stack(samples) - attention_mask = torch.ones_like(input_ids, dtype=torch.long) - return {"input_ids": input_ids, "attention_mask": attention_mask} - - -def print_rank_0(*args, **kwargs) -> None: - if dist.get_rank() == 0: - print(*args, **kwargs) - - -def print_model_numel(model_dict: dict) -> None: - B = 1024**3 - M = 1024**2 - K = 1024 - outputs = "" - for name, numel in model_dict.items(): - outputs += f"{name}: " - if numel >= B: - outputs += f"{numel / B:.2f} B\n" - elif numel >= M: - outputs += f"{numel / M:.2f} M\n" - elif numel >= K: - outputs += f"{numel / K:.2f} K\n" - else: - outputs += f"{numel}\n" - print_rank_0(outputs) - - -def get_gpt_config(model_name: str) -> OPTConfig: - model_map = { - "125m": OPTConfig.from_pretrained("facebook/opt-125m"), - "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), - "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), - "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"), - "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"), - "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), - "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), - "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"), - "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), - "13b": OPTConfig.from_pretrained("facebook/opt-13b"), - } - try: - return model_map[model_name] - except KeyError: - raise ValueError(f'Unknown model "{model_name}"') - - -def main(args): - if args.strategy == "ddp": - strategy = DDPStrategy() - elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5) - elif 
args.strategy == "colossalai_gemini_cpu": - strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) - elif args.strategy == "colossalai_zero2": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") - elif args.strategy == "colossalai_zero2_cpu": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") - elif args.strategy == "colossalai_zero1": - strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda") - elif args.strategy == "colossalai_zero1_cpu": - strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu") - else: - raise ValueError(f'Unsupported strategy "{args.strategy}"') - - torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac) - - model_config = get_gpt_config(args.model) - critic_config = get_gpt_config(args.critic_model) - with strategy.model_init_context(): - actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda() - critic = OPTCritic(config=critic_config, lora_rank=args.lora_rank).cuda() - - initial_model = deepcopy(actor).cuda().half() - reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half() - - if args.use_kernels: - from coati.kernels import convert_to_xformer_model - - actor, critic, initial_model, reward_model = map( - convert_to_xformer_model, (actor, critic, initial_model, reward_model) - ) - - actor_numel = get_model_numel(actor, strategy) - critic_numel = get_model_numel(critic, strategy) - initial_model_numel = get_model_numel(initial_model, strategy) - reward_model_numel = get_model_numel(reward_model, strategy) - print_model_numel( - { - "Actor": actor_numel, - "Critic": critic_numel, - "Initial model": initial_model_numel, - "Reward model": reward_model_numel, - } - ) - performance_evaluator = PerformanceEvaluator( - actor_numel, - critic_numel, - initial_model_numel, - reward_model_numel, - enable_grad_checkpoint=False, - ignore_episodes=1, - ) - - if args.strategy.startswith("colossalai"): - actor_optim = HybridAdam(actor.parameters(), lr=5e-6) - critic_optim = HybridAdam(critic.parameters(), lr=5e-6) - else: - actor_optim = Adam(actor.parameters(), lr=5e-6) - critic_optim = Adam(critic.parameters(), lr=5e-6) - - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) - - random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) - dataloader = DataLoader( - random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch - ) - - trainer = PPOTrainer( - strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - tokenizer=tokenizer, - ptx_coef=0, - train_batch_size=args.train_batch_size, - offload_inference_models=args.offload_inference_models, - max_length=512, - do_sample=True, - temperature=1.0, - top_k=50, - use_cache=True, - callbacks=[performance_evaluator], - ) - - trainer.fit( - prompt_dataloader=dataloader, - pretrain_dataloader=None, - num_episodes=args.num_episodes, - num_update_steps=args.num_update_steps, - num_collect_steps=args.num_collect_steps, - ) - - print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model", default="125m") - 
parser.add_argument("--critic_model", default="125m") - parser.add_argument( - "--strategy", - choices=[ - "ddp", - "colossalai_gemini", - "colossalai_gemini_cpu", - "colossalai_zero2", - "colossalai_zero2_cpu", - "colossalai_zero1", - "colossalai_zero1_cpu", - ], - default="ddp", - ) - parser.add_argument("--num_episodes", type=int, default=3) - parser.add_argument("--num_collect_steps", type=int, default=8) - parser.add_argument("--num_update_steps", type=int, default=1) - parser.add_argument("--train_batch_size", type=int, default=8) - parser.add_argument("--experience_batch_size", type=int, default=8) - parser.add_argument("--lora_rank", type=int, default=0) - parser.add_argument("--cuda_mem_frac", type=float, default=1.0) - parser.add_argument("--offload_inference_models", action="store_true", default=False) - parser.add_argument("--use_kernels", action="store_true", default=False) - args = parser.parse_args() - main(args) diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py deleted file mode 100644 index 599b57609..000000000 --- a/applications/Chat/coati/dataset/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .prompt_dataset import PromptDataset -from .reward_dataset import HhRlhfDataset, RmStaticDataset -from .sft_dataset import SFTDataset, SupervisedDataset -from .utils import is_rank_0 - -__all__ = [ - "RmStaticDataset", - "HhRlhfDataset", - "SFTDataset", - "SupervisedDataset", - "PromptDataset", - "is_rank_0", -] diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py deleted file mode 100644 index f2180d96b..000000000 --- a/applications/Chat/coati/dataset/conversation.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2023 lm-sys@FastChat -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses -from enum import Enum, auto -from typing import List - - -class SeparatorStyle(Enum): - ADD_EOS_TOKEN = auto() - - -@dataclasses.dataclass -class Conversation: - system: str - roles: List[str] - messages: List[List[str]] - offset: int - sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN - sep: str = "
                  " - - skip_next: bool = False - - def get_prompt(self): - if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN: - ret = self.system - for role, message in self.messages: - if message: - ret += role + ": " + message + self.sep - else: - ret += role + ": " - return ret - else: - raise ValueError(f"Invalid style: {self.sep_style}") - - def append_message(self, role, message): - self.messages.append([role, message]) - - def to_gradio_chatbot(self): - ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): - if i % 2 == 0: - ret.append([msg, None]) - else: - ret[-1][-1] = msg - return ret - - def copy(self): - return Conversation( - system=self.system, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep, - ) - - def dict(self): - return { - "system": self.system, - "roles": self.roles, - "messages": self.messages, - "offset": self.offset, - "sep": self.sep, - } - - -conv = Conversation( - system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", - roles=("Human", "Assistant"), - messages=(), - offset=0, - sep_style=SeparatorStyle.ADD_EOS_TOKEN, - sep="
                  ", -) - -default_conversation = conv diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py deleted file mode 100644 index 17120e606..000000000 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections import defaultdict -from typing import Dict - -import torch -import transformers -from torch.utils.data import Dataset - -from colossalai.logging import get_dist_logger - -from .utils import jload - - -class PromptDataset(Dataset): - """Dataset for supervised fine-tuning.""" - - def __init__( - self, - data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 96, - ): - super(PromptDataset, self).__init__() - self.keyed_prompt = defaultdict(list) - self.logger = get_dist_logger() - self.logger.info("Loading data...") - list_data_dict = jload(data_path) - self.logger.info(f"Loaded {len(list_data_dict)} examples.") - - if max_datasets_size is not None: - self.logger.info(f"Limiting dataset to {max_datasets_size} examples.") - list_data_dict = list_data_dict[:max_datasets_size] - - instructions = [data_dict["instruction"] for data_dict in list_data_dict] - tokens = tokenizer( - instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True - ) - for k, tensor in tokens.items(): - self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind() - - def __len__(self): - return len(self.keyed_prompt["input_ids"]) - - def __getitem__(self, i) -> Dict[str, torch.Tensor]: - return {k: v[i] for k, v in self.keyed_prompt.items()} diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py deleted file mode 100644 index 3afcd7b69..000000000 --- a/applications/Chat/coati/dataset/reward_dataset.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Callable - -from torch.utils.data import Dataset -from tqdm import tqdm - -from .utils import is_rank_0 - - -# Dahoas/rm-static -class RmStaticDataset(Dataset): - """ - Dataset for reward model - - Args: - dataset: dataset for reward model - tokenizer: tokenizer for reward model - max_length: max length of input - special_token: special token at the end of sentence - """ - - def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: - super().__init__() - self.end_token = tokenizer.eos_token if special_token is None else special_token - - chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] - chosen_token = tokenizer( - chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} - - reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] - reject_token = tokenizer( - reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} - - def __len__(self): - length = self.chosen["input_ids"].shape[0] - return length - - def __getitem__(self, idx): - return ( - self.chosen["input_ids"][idx], - self.chosen["attention_mask"][idx], - self.reject["input_ids"][idx], - self.reject["attention_mask"][idx], - ) - - -# Anthropic/hh-rlhf -class HhRlhfDataset(Dataset): - """ - Dataset 
for reward model - - Args: - dataset: dataset for reward model - tokenizer: tokenizer for reward model - max_length: max length of input - special_token: special token at the end of sentence - """ - - def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: - super().__init__() - self.end_token = tokenizer.eos_token if special_token is None else special_token - - chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] - chosen_token = tokenizer( - chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} - - reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] - reject_token = tokenizer( - reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} - - def __len__(self): - length = self.chosen["input_ids"].shape[0] - return length - - def __getitem__(self, idx): - return ( - self.chosen["input_ids"][idx], - self.chosen["attention_mask"][idx], - self.reject["input_ids"][idx], - self.reject["attention_mask"][idx], - ) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py deleted file mode 100644 index e67e16231..000000000 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -from typing import Dict, Optional, Sequence, Tuple - -import torch -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer -from torch.utils.data import Dataset -from tqdm import tqdm -from transformers import PreTrainedTokenizer - -from colossalai.logging import get_dist_logger - -from .utils import is_rank_0, jload - -logger = get_dist_logger() - -IGNORE_INDEX = -100 -PROMPT_DICT = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. 
" - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), -} - - -def _preprocess( - sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Preprocess the data by tokenizing.""" - sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)] - sequences_token = tokenizer( - sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False - ) - - sources_token = tokenizer( - sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False - ) - - assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently" - labels = copy.deepcopy(sequences_token["input_ids"]) - for i in range(labels.shape[0]): - source_len = sources_token["attention_mask"][i].sum().item() - pad_len = max_length - sequences_token["attention_mask"][i].sum().item() - if tokenizer.padding_side == "right": - # |prompt|completion|eos|pad| - labels[i][:source_len] = IGNORE_INDEX - if pad_len>0: - labels[i][-pad_len:] = IGNORE_INDEX - elif tokenizer.padding_side == "left": - # |pad|prompt|completion|eos| - labels[i][: pad_len + source_len] = IGNORE_INDEX - else: - raise RuntimeError() - - return sequences_token["input_ids"], labels, sequences_token["attention_mask"] - - -def _preprocess_chatglm( - sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Preprocess the data by tokenizing. - None for attention mask, ChatGLM will calculate attention mask according to input ids - """ - - labels = [] - input_ids = [] - for source, target in zip(sources, targets): - source_id = tokenizer.encode(text=source, add_special_tokens=False) - target_id = tokenizer.encode(text=target, add_special_tokens=False) - input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id) - # truncate - sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id] - truncate_length = max(0, len(input_id) - max_length) - input_id = input_id[truncate_length:] - if truncate_length == len(source_id) + 1: - input_id = sp_token_list + input_id[1:] - elif truncate_length > len(source_id) + 1: - input_id = sp_token_list + input_id[2:] - - context_length = input_id.index(tokenizer.bos_token_id) - mask_position = context_length - 1 - label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :] - - pad_len = max_length - len(input_id) - input_id = input_id + [tokenizer.pad_token_id] * pad_len - input_ids.append(input_id) - labels.append(label + [IGNORE_INDEX] * pad_len) - return torch.tensor(input_ids), torch.tensor(labels), None - - -class SFTDataset(Dataset): - """ - Dataset for sft model - - Args: - dataset: dataset for supervised model - tokenizer: tokenizer for supervised model - max_length: max length of input - """ - - def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None: - super().__init__() - self.input_ids = [] - - sources = [data["prompt"] for data in dataset] - targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] - - logger.info("Tokenizing inputs... 
This may take some time...") - if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( - sources, targets, tokenizer, max_length - ) - else: - self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) - - logger.info("Loaded dataset.") - - def __len__(self): - length = self.input_ids.shape[0] - return length - - def __getitem__(self, idx): - if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) - else: - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) - - -class SupervisedDataset(Dataset): - """Dataset for supervised fine-tuning.""" - - def __init__( - self, - data_path: str, - tokenizer: PreTrainedTokenizer, - max_datasets_size: Optional[int] = None, - max_length: int = 512, - ): - super().__init__() - logger.info("Loading data...") - list_data_dict = jload(data_path) - logger.info(f"Loaded {len(list_data_dict)} examples.") - - if max_datasets_size is not None: - logger.info(f"Limiting dataset to {max_datasets_size} examples.") - list_data_dict = list_data_dict[:max_datasets_size] - - logger.info("Formatting inputs...") - prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] - sources = [ - prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) - for example in list_data_dict - ] - targets = [example["output"] + tokenizer.eos_token for example in list_data_dict] - - logger.info("Tokenizing inputs... This may take some time...") - if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( - sources, targets, tokenizer, max_length - ) - else: - self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) - - logger.info("Loaded dataset.") - - def __len__(self): - length = self.input_ids.shape[0] - return length - - def __getitem__(self, idx): - if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) - else: - return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) diff --git a/applications/Chat/coati/dataset/utils.py b/applications/Chat/coati/dataset/utils.py deleted file mode 100644 index f37fce67a..000000000 --- a/applications/Chat/coati/dataset/utils.py +++ /dev/null @@ -1,22 +0,0 @@ -import io -import json - -import torch.distributed as dist - - -def is_rank_0() -> bool: - return not dist.is_initialized() or dist.get_rank() == 0 - - -def _make_r_io_base(f, mode: str): - if not isinstance(f, io.IOBase): - f = open(f, mode=mode) - return f - - -def jload(f, mode="r"): - """Load a .json file into a dictionary.""" - f = _make_r_io_base(f, mode) - jdict = json.load(f) - f.close() - return jdict diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py deleted file mode 100644 index 941e1994b..000000000 --- a/applications/Chat/coati/experience_maker/naive.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import torch.nn.functional as F -from coati.models.base import Actor, Critic, RewardModel -from coati.models.generation import generate -from coati.models.utils import calc_action_log_probs, compute_reward -from transformers import PreTrainedTokenizer - -from .base import Experience, ExperienceMaker - - -class 
NaiveExperienceMaker(ExperienceMaker): - """ - Naive experience maker. - """ - - def __init__( - self, - actor: Actor, - critic: Critic, - reward_model: RewardModel, - initial_model: Actor, - tokenizer: PreTrainedTokenizer, - kl_coef: float = 0.1, - ) -> None: - super().__init__(actor, critic, reward_model, initial_model) - self.tokenizer = tokenizer - self.kl_coef = kl_coef - - @torch.no_grad() - def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: - self.actor.eval() - self.critic.eval() - self.initial_model.eval() - self.reward_model.eval() - - # generate sequences - sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs) - - # calculate auxiliary tensors - attention_mask = None - pad_token_id = self.tokenizer.pad_token_id - if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) - - input_len = input_ids.size(1) - eos_token_id = self.tokenizer.eos_token_id - if eos_token_id is None: - action_mask = torch.ones_like(sequences, dtype=torch.bool) - else: - # left padding may be applied, only mask action - action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input - action_mask[:, :input_len] = False - action_mask = action_mask[:, 1:] - action_mask = action_mask[:, -(sequences.size(1) - input_len) :] - num_actions = action_mask.size(1) - - actor_output = self.actor(sequences, attention_mask)["logits"] - action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) - base_model_output = self.initial_model(sequences, attention_mask)["logits"] - base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) - value = self.critic(sequences, attention_mask) - r = self.reward_model(sequences, attention_mask) - reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) - - advantage = reward - value - # TODO(ver217): maybe normalize adv - if advantage.ndim == 1: - advantage = advantage.unsqueeze(-1) - - return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py deleted file mode 100644 index 96d40c7c4..000000000 --- a/applications/Chat/coati/kernels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .wrapper import convert_to_xformer_model, recover_from_xformer_model - -__all__ = [ - "convert_to_xformer_model", - "recover_from_xformer_model", -] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py deleted file mode 100644 index d1eb13918..000000000 --- a/applications/Chat/coati/kernels/opt_attn.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional, Tuple - -import torch -import xformers.ops as xops -from torch import Tensor -from transformers.models.opt.modeling_opt import OPTAttention - - -# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py -class XOPTAttention(OPTAttention): - # def _shape(self, tensor: Tensor, seq_len: int, bsz: int): - # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() - - def forward( - self, - hidden_states: Tensor, - key_value_states: Optional[Tensor] = None, - past_key_value: Optional[Tensor] = None, - attention_mask: Optional[Tensor] = None, - 
layer_head_mask: Optional[Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: - if not self.training: - return super().forward( - hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions - ) - """Input shape: Batch x Time x Channel""" - assert layer_head_mask is None, "Xformers attention does not support layer_head_mask" - assert not output_attentions, "Xformers attention does not support output_attentions" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = xops.memory_efficient_attention( - query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask(), - p=self.dropout if self.training else 0.0, - scale=self.scaling, - ) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
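# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch only, not part of the original patch.
# The xformers call above computes causal attention; the snippet below is a
# plain-PyTorch equivalent (ignoring dropout) with toy shapes, ending in the
# same [bsz, tgt_len, embed_dim] reshape performed on the next line. All names
# and sizes here (bsz, seq, heads, head_dim, ...) are assumptions for the demo.
import torch

bsz, seq, heads, head_dim = 2, 5, 4, 8
q = torch.randn(bsz, seq, heads, head_dim)
k = torch.randn(bsz, seq, heads, head_dim)
v = torch.randn(bsz, seq, heads, head_dim)

qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))          # [bsz, heads, seq, head_dim]
scores = qh @ kh.transpose(-1, -2) / head_dim**0.5           # scaled dot-product scores
causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))           # LowerTriangularMask equivalent
out = (scores.softmax(dim=-1) @ vh).transpose(1, 2)          # back to [bsz, seq, heads, head_dim]
print(out.reshape(bsz, seq, heads * head_dim).shape)         # -> [bsz, seq, embed_dim]
# ---------------------------------------------------------------------------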
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - attn_weights_reshaped = None - - return attn_output, attn_weights_reshaped, past_key_value diff --git a/applications/Chat/coati/kernels/wrapper.py b/applications/Chat/coati/kernels/wrapper.py deleted file mode 100644 index c55bda600..000000000 --- a/applications/Chat/coati/kernels/wrapper.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch.nn as nn -from transformers.models.opt.modeling_opt import OPTAttention - -from .opt_attn import XOPTAttention - - -def convert_to_xformer_model(model: nn.Module) -> nn.Module: - for module in model.modules(): - if isinstance(module, OPTAttention): - module.__class__ = XOPTAttention - return model - - -def recover_from_xformer_model(model: nn.Module) -> nn.Module: - for module in model.modules(): - if isinstance(module, XOPTAttention): - module.__class__ = OPTAttention - return model diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py deleted file mode 100644 index ad4a525b4..000000000 --- a/applications/Chat/coati/models/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .base import Actor, Critic, RewardModel -from .lora import LoRAModule, convert_to_lora_module -from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss - -__all__ = [ - "Actor", - "Critic", - "RewardModel", - "PolicyLoss", - "ValueLoss", - "LogSigLoss", - "LogExpLoss", - "LoRAModule", - "convert_to_lora_module", -] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py deleted file mode 100644 index 5c9905bb2..000000000 --- a/applications/Chat/coati/models/base/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Union - -import torch.nn as nn - -from .actor import Actor -from .critic import Critic -from .reward_model import RewardModel - - -def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: - """Get the base model of our wrapper classes. - For Actor, Critic and RewardModel, return ``model.model``, - it's usually a ``transformers.PreTrainedModel``. - - Args: - model (nn.Module): model to get base model from - - Returns: - nn.Module: the base model - """ - assert isinstance( - model, (Actor, Critic, RewardModel) - ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first." - return model.model - - -__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"] diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py deleted file mode 100644 index 8b2b81ed0..000000000 --- a/applications/Chat/coati/models/base/actor.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - -from ..lora import LoRAModule - - -class Actor(LoRAModule): - """ - Actor model base class. - - Args: - model (nn.Module): Actor Model. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) - self.model = model - self.convert_to_lora() - - def forward( - self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, - ) -> torch.Tensor: - """Returns model output.""" - output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) - return output - diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py deleted file mode 100644 index 8672365f5..000000000 --- a/applications/Chat/coati/models/base/critic.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import torch.nn as nn - -from ..lora import LoRAModule - - -class Critic(LoRAModule): - """ - Critic model base class. - - Args: - model (nn.Module): Critic model. - value_head (nn.Module): Value head to get value. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none" - ) -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) - self.model = model - self.value_head = value_head - self.convert_to_lora() - - def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: - outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs["last_hidden_state"] - sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ - 0 - ] - sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] - values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) - return values diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py deleted file mode 100644 index e9545d1cd..000000000 --- a/applications/Chat/coati/models/base/reward_model.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - -from ..lora import LoRAModule - - -class RewardModel(LoRAModule): - """ - Reward model base class. - - Args: - model (nn.Module): Reward model. - value_head (nn.Module): Value head to get reward score. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__( - self, - model: nn.Module, - value_head: Optional[nn.Module] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) - self.model = model - self.convert_to_lora() - - if value_head is not None: - if value_head.out_features != 1: - raise ValueError("The value head of reward model's output dim should be 1!") - self.value_head = value_head - else: - self.value_head = nn.Linear(model.config.n_embd, 1) - - def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor: - outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs["last_hidden_state"] - sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[ - 0 - ] - sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths] - values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, ) - return values diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py deleted file mode 100644 index 7af199a67..000000000 --- a/applications/Chat/coati/models/bloom/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .bloom_actor import BLOOMActor -from .bloom_critic import BLOOMCritic -from .bloom_rm import BLOOMRM - -__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"] diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py deleted file mode 100644 index 73855a224..000000000 --- a/applications/Chat/coati/models/bloom/bloom_actor.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -from transformers import BloomConfig, BloomForCausalLM - -from ..base import Actor - - -class BLOOMActor(Actor): - """ - BLOOM Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (BloomConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = BloomForCausalLM.from_pretrained(pretrained) - elif config is not None: - model = BloomForCausalLM(config) - else: - model = BloomForCausalLM(BloomConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py deleted file mode 100644 index b2d838f7f..000000000 --- a/applications/Chat/coati/models/bloom/bloom_critic.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import BloomConfig, BloomModel - -from ..base import Critic - - -class BLOOMCritic(Critic): - """ - BLOOM Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (BloomConfig): Model config. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__( - self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - **kwargs, - ) -> None: - if pretrained is not None: - model = BloomModel.from_pretrained(pretrained) - elif config is not None: - model = BloomModel(config) - else: - model = BloomModel(BloomConfig()) - - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py deleted file mode 100644 index c09457ddc..000000000 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import BloomConfig, BloomModel - -from ..base import RewardModel - - -class BLOOMRM(RewardModel): - """ - BLOOM Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (BloomConfig): Model config. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = BloomModel.from_pretrained(pretrained) - elif config is not None: - model = BloomModel(config) - else: - model = BloomModel(BloomConfig()) - - value_head = nn.Linear(model.config.hidden_size, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py deleted file mode 100644 index 5956f5a8e..000000000 --- a/applications/Chat/coati/models/chatglm/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .chatglm_actor import ChatGLMActor - -__all__ = ["ChatGLMActor"] diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py deleted file mode 100644 index 00a61561e..000000000 --- a/applications/Chat/coati/models/chatglm/chatglm_actor.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Optional - -from ..base import Actor -from .configuration_chatglm import ChatGLMConfig -from .modeling_chatglm import ChatGLMForConditionalGeneration - - -class ChatGLMActor(Actor): - """ - ChatGLM Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (ChatGLMConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - - do not support lora for now. 
- """ - - def __init__( - self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False - ) -> None: - if pretrained is not None: - model = ChatGLMForConditionalGeneration.from_pretrained(pretrained) - elif config is not None: - model = ChatGLMForConditionalGeneration(config) - else: - model = ChatGLMForConditionalGeneration(ChatGLMConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - super().__init__(model, lora_rank=0, lora_train_bias="none") diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py deleted file mode 100644 index 221ef044b..000000000 --- a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py +++ /dev/null @@ -1,442 +0,0 @@ -""" -This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py -""" -"""Tokenization classes for ChatGLM.""" -import os -from typing import Dict, List, Optional, Union - -import numpy as np -import sentencepiece as spm -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.tokenization_utils_base import BatchEncoding, EncodedInput -from transformers.utils import PaddingStrategy, logging - -logger = logging.get_logger(__name__) - -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "THUDM/chatglm-6b": 2048, -} - - -class TextTokenizer: - def __init__(self, model_path): - self.sp = spm.SentencePieceProcessor() - self.sp.Load(model_path) - self.num_tokens = self.sp.vocab_size() - - def encode(self, text): - return self.sp.EncodeAsIds(text) - - def decode(self, ids: List[int]): - return self.sp.DecodeIds(ids) - - def tokenize(self, text): - return self.sp.EncodeAsPieces(text) - - def convert_tokens_to_string(self, tokens): - return self.sp.DecodePieces(tokens) - - def convert_tokens_to_ids(self, tokens): - return [self.sp.PieceToId(token) for token in tokens] - - def convert_token_to_id(self, token): - return self.sp.PieceToId(token) - - def convert_id_to_token(self, idx): - return self.sp.IdToPiece(idx) - - def __len__(self): - return self.num_tokens - - -class SPTokenizer: - def __init__( - self, - vocab_file, - num_image_tokens=20000, - max_blank_length=80, - byte_fallback=True, - ): - assert vocab_file is not None - self.vocab_file = vocab_file - self.num_image_tokens = num_image_tokens - self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] - self.max_blank_length = max_blank_length - self.byte_fallback = byte_fallback - self.text_tokenizer = TextTokenizer(vocab_file) - - def _get_text_tokenizer(self): - return self.text_tokenizer - - @staticmethod - def get_blank_token(length: int): - assert length >= 2 - return f"<|blank_{length}|>" - - @staticmethod - def get_tab_token(): - return f"<|tab|>" - - @property - def num_text_tokens(self): - return self.text_tokenizer.num_tokens - - @property - def num_tokens(self): - return self.num_image_tokens + self.num_text_tokens - - @staticmethod - def _encode_whitespaces(text: str, max_len: int = 80): - text = text.replace("\t", SPTokenizer.get_tab_token()) - for i in range(max_len, 1, -1): - text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) - return text - - def _preprocess(self, text: str, linebreak=True, whitespaces=True): - if linebreak: - text = text.replace("\n", "") - if whitespaces: - text = self._encode_whitespaces(text, max_len=self.max_blank_length) - return text - - def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]: - """ - @param 
text: Text to encode. - @param linebreak: Whether to encode newline (\n) in text. - @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. - @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. - @param add_dummy_prefix: Whether to add dummy blank space in the beginning. - """ - text = self._preprocess(text, linebreak, whitespaces) - if not add_dummy_prefix: - text = "" + text - tmp = self._get_text_tokenizer().encode(text) - tokens = [x + self.num_image_tokens for x in tmp] - return tokens if add_dummy_prefix else tokens[2:] - - def postprocess(self, text): - text = text.replace("", "\n") - text = text.replace(SPTokenizer.get_tab_token(), "\t") - for i in range(2, self.max_blank_length + 1): - text = text.replace(self.get_blank_token(i), " " * i) - return text - - def decode(self, text_ids: List[int]) -> str: - ids = [int(_id) - self.num_image_tokens for _id in text_ids] - ids = [_id for _id in ids if _id >= 0] - text = self._get_text_tokenizer().decode(ids) - text = self.postprocess(text) - return text - - def decode_tokens(self, tokens: List[str]) -> str: - text = self._get_text_tokenizer().convert_tokens_to_string(tokens) - text = self.postprocess(text) - return text - - def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]: - """ - @param text: Text to encode. - @param linebreak: Whether to encode newline (\n) in text. - @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. - @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. - @param add_dummy_prefix: Whether to add dummy blank space in the beginning. - """ - text = self._preprocess(text, linebreak, whitespaces) - if not add_dummy_prefix: - text = "" + text - tokens = self._get_text_tokenizer().tokenize(text) - return tokens if add_dummy_prefix else tokens[2:] - - def __getitem__(self, x: Union[int, str]): - if isinstance(x, int): - if x < self.num_image_tokens: - return "".format(x) - else: - return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) - elif isinstance(x, str): - if x.startswith("") and x[7:-1].isdigit(): - return int(x[7:-1]) - else: - return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens - else: - raise ValueError("The key should be str or int.") - - -class ChatGLMTokenizer(PreTrainedTokenizer): - """ - Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (`str`): - Path to the vocabulary file. 
- """ - - vocab_files_names = {"vocab_file": "ice_text.model"} - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask", "position_ids"] - - def __init__( - self, - vocab_file, - do_lower_case=False, - remove_space=False, - bos_token="", - eos_token="", - end_token="", - mask_token="[MASK]", - gmask_token="[gMASK]", - padding_side="left", - pad_token="", - unk_token="", - num_image_tokens=20000, - **kwargs, - ) -> None: - super().__init__( - do_lower_case=do_lower_case, - remove_space=remove_space, - padding_side=padding_side, - bos_token=bos_token, - eos_token=eos_token, - end_token=end_token, - mask_token=mask_token, - gmask_token=gmask_token, - pad_token=pad_token, - unk_token=unk_token, - num_image_tokens=num_image_tokens, - **kwargs, - ) - - self.do_lower_case = do_lower_case - self.remove_space = remove_space - self.vocab_file = vocab_file - - self.bos_token = bos_token - self.eos_token = eos_token - self.end_token = end_token - self.mask_token = mask_token - self.gmask_token = gmask_token - - self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) - - """ Initialisation """ - - @property - def gmask_token_id(self) -> Optional[int]: - if self.gmask_token is None: - return None - return self.convert_tokens_to_ids(self.gmask_token) - - @property - def end_token_id(self) -> Optional[int]: - """ - `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been - set. - """ - if self.end_token is None: - return None - return self.convert_tokens_to_ids(self.end_token) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_tokenizer.num_tokens - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def preprocess_text(self, inputs): - if self.remove_space: - outputs = " ".join(inputs.strip().split()) - else: - outputs = inputs - - if self.do_lower_case: - outputs = outputs.lower() - - return outputs - - def _tokenize(self, text, **kwargs): - """Returns a tokenized string.""" - text = self.preprocess_text(text) - - seq = self.sp_tokenizer.tokenize(text) - - return seq - - def convert_tokens_to_string(self, tokens: List[str]) -> str: - return self.sp_tokenizer.decode_tokens(tokens) - - def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: - if isinstance(token_ids, int): - token_ids = [token_ids] - if len(token_ids) == 0: - return "" - if self.pad_token_id in token_ids: # remove pad - token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) - return super()._decode(token_ids, **kwargs) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_tokenizer[token] - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - return self.sp_tokenizer[index] - - def save_vocabulary(self, save_directory, filename_prefix=None): - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - filename_prefix (`str`, *optional*): - An optional prefix to add to the named of the saved files. - - Returns: - `Tuple(str)`: Paths to the files saved. 
- """ - if os.path.isdir(save_directory): - vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) - else: - vocab_file = save_directory - - with open(self.vocab_file, "rb") as fin: - proto_str = fin.read() - - with open(vocab_file, "wb") as writer: - writer.write(proto_str) - - return (vocab_file,) - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and - adding special tokens. A BERT sequence has the following format: - - - single sequence: `[CLS] X [SEP]` - - pair of sequences: `[CLS] A [SEP] B [SEP]` - - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - gmask_id = self.sp_tokenizer[self.gmask_token] - self.sp_tokenizer[self.eos_token] - token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] - if token_ids_1 is not None: - token_ids_0 = token_ids_0 + token_ids_1 - return token_ids_0 - - def _pad( - self, - encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], - max_length: Optional[int] = None, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, - ) -> dict: - """ - Pad encoded inputs (on left/right and up to predefined length or max length in the batch) - - Args: - encoded_inputs: - Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). - max_length: maximum length of the returned list and optionally padding length (see below). - Will truncate by taking into account the special tokens. - padding_strategy: PaddingStrategy to use for padding. - - - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: - - - 'left': pads on the left of the sequences - - 'right': pads on the right of the sequences - pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. - This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability - `>= 7.5` (Volta). - return_attention_mask: - (optional) Set to False to avoid returning attention mask (default: set to model specifics) - """ - # Load from model defaults - bos_token_id = self.sp_tokenizer[self.bos_token] - mask_token_id = self.sp_tokenizer[self.mask_token] - gmask_token_id = self.sp_tokenizer[self.gmask_token] - assert self.padding_side == "left" - - required_input = encoded_inputs[self.model_input_names[0]] - seq_length = len(required_input) - - if padding_strategy == PaddingStrategy.LONGEST: - max_length = len(required_input) - - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length - - # Initialize attention mask if not present. 
- if max_length is not None: - if "attention_mask" not in encoded_inputs: - if bos_token_id in required_input: - context_length = required_input.index(bos_token_id) - else: - context_length = seq_length - attention_mask = np.ones((1, seq_length, seq_length)) - attention_mask = np.tril(attention_mask) - attention_mask[:, :, :context_length] = 1 - attention_mask = np.bool_(attention_mask < 0.5) - encoded_inputs["attention_mask"] = attention_mask - - if "position_ids" not in encoded_inputs: - if bos_token_id in required_input: - context_length = required_input.index(bos_token_id) - else: - context_length = seq_length - position_ids = np.arange(seq_length, dtype=np.int64) - mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id - if mask_token in required_input: - mask_position = required_input.index(mask_token) - position_ids[context_length:] = mask_position - block_position_ids = np.concatenate( - [ - np.zeros(context_length, dtype=np.int64), - np.arange(1, seq_length - context_length + 1, dtype=np.int64), - ] - ) - encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) - - if needs_to_be_padded: - difference = max_length - len(required_input) - - if "attention_mask" in encoded_inputs: - encoded_inputs["attention_mask"] = np.pad( - encoded_inputs["attention_mask"], - pad_width=[(0, 0), (difference, 0), (difference, 0)], - mode="constant", - constant_values=True, - ) - if "token_type_ids" in encoded_inputs: - encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ - "token_type_ids" - ] - if "special_tokens_mask" in encoded_inputs: - encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] - if "position_ids" in encoded_inputs: - encoded_inputs["position_ids"] = np.pad( - encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)] - ) - encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input - - return encoded_inputs diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py deleted file mode 100644 index a6d2ccd18..000000000 --- a/applications/Chat/coati/models/chatglm/configuration_chatglm.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py -""" - -""" ChatGLM model configuration """ - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class ChatGLMConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`~ChatGLMModel`]. - It is used to instantiate an ChatGLM model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used - to control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 150528): - Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`~ChatGLMModel`] or - [`~TFChatGLMModel`]. 
- hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the encoder layers and the pooler layer. - num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - inner_hidden_size (`int`, *optional*, defaults to 16384): - Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - max_sequence_length (`int`, *optional*, defaults to 512): - The maximum sequence length that this model might ever be used with. - Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - layernorm_epsilon (`float`, *optional*, defaults to 1e-5): - The epsilon used by the layer normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether the model should return the last key/values attentions (not used by all models). - Example: - - ```python - >>> from configuration_chatglm import ChatGLMConfig - >>> from modeling_chatglm import ChatGLMModel - - >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration - >>> configuration = ChatGLMConfig() - - >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration - >>> model = ChatGLMModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "chatglm" - - def __init__( - self, - vocab_size=130528, - hidden_size=4096, - num_layers=28, - num_attention_heads=32, - layernorm_epsilon=1e-5, - use_cache=True, - bos_token_id=130004, - eos_token_id=130005, - mask_token_id=130000, - gmask_token_id=130001, - pad_token_id=3, - max_sequence_length=2048, - inner_hidden_size=16384, - position_encoding_2d=True, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs, - ): - self.num_layers = num_layers - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.max_sequence_length = max_sequence_length - self.layernorm_epsilon = layernorm_epsilon - self.inner_hidden_size = inner_hidden_size - self.use_cache = use_cache - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.mask_token_id = mask_token_id - self.gmask_token_id = gmask_token_id - self.position_encoding_2d = position_encoding_2d - self.quantization_bit = quantization_bit - self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py deleted file mode 100644 index d1d15c68f..000000000 --- a/applications/Chat/coati/models/chatglm/modeling_chatglm.py +++ /dev/null @@ -1,1477 +0,0 @@ -""" -This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py -""" - -""" PyTorch ChatGLM model. 
""" - -import copy -import math -import os -import re -import sys -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != "darwin": - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm-6b", - # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm -] - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." 
- ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(config.hidden_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2), - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) - - -def gelu(x): - return gelu_impl(x) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half, learnable=False): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - inv_freq = inv_freq.half() - self.learnable = learnable - if learnable: - self.inv_freq = 
torch.nn.Parameter(inv_freq) - self.max_seq_len_cached = None - else: - self.register_buffer("inv_freq", inv_freq) - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - pass - - def forward(self, x, seq_dim=1, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] - if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): - self.max_seq_len_cached = None if self.learnable else seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - if self.precision == torch.bfloat16: - emb = emb.float() - - # [sx, 1 (b * np), hn] - cos_cached = emb.cos()[:, None, :] - sin_cached = emb.sin()[:, None, :] - if self.precision == torch.bfloat16: - cos_cached = cos_cached.bfloat16() - sin_cached = sin_cached.bfloat16() - if self.learnable: - return cos_cached, sin_cached - self.cos_cached, self.sin_cached = cos_cached, sin_cached - return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] - - def _apply(self, fn): - if self.cos_cached is not None: - self.cos_cached = fn(self.cos_cached) - if self.sin_cached is not None: - self.sin_cached = fn(self.sin_cached) - return super()._apply(fn) - - -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): - # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding( - position_id, sin.squeeze(1) - ).unsqueeze(2) - q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - return q, k - - -def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, -): - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key_layer = torch.cat((past_key, key_layer), dim=0) - value_layer = torch.cat((past_value, value_layer), dim=0) - - # seqlen, batch, num_attention_heads, hidden_size_per_attention_head - seq_len, b, nh, hidden_size = key_layer.shape - - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - query_key_layer_scaling_coeff = float(layer_id + 1) - if scaling_attention_score: - query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) - - # =================================== - # Raw attention scores. 
[b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - matmul_result = torch.zeros( - 1, - 1, - 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - if self.scale_mask_softmax: - self.scale_mask_softmax.scale = query_key_layer_scaling_coeff - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) - else: - if not (attention_mask == 0).all(): - # if auto-regressive, skip - attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - attention_scores = attention_scores * query_key_layer_scaling_coeff - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, present, attention_probs) - - return outputs - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class SelfAttention(torch.nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - layer_id, - hidden_size_per_attention_head=None, - bias=True, - params_dtype=torch.float, - position_encoding_2d=True, - empty_init=True, - ): - if empty_init: - init_method = skip_init - else: - init_method = default_init - super(SelfAttention, self).__init__() - - self.layer_id = layer_id - self.hidden_size = hidden_size - self.hidden_size_per_partition = hidden_size - self.num_attention_heads = num_attention_heads - self.num_attention_heads_per_partition = num_attention_heads - self.position_encoding_2d = position_encoding_2d - self.rotary_emb = RotaryEmbedding( - self.hidden_size // (self.num_attention_heads * 2) - if position_encoding_2d - else self.hidden_size // self.num_attention_heads, - base=10000, - precision=torch.half, - learnable=False, - ) - - self.scale_mask_softmax = None - - if hidden_size_per_attention_head is None: - 
self.hidden_size_per_attention_head = hidden_size // num_attention_heads - else: - self.hidden_size_per_attention_head = hidden_size_per_attention_head - - self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head - - # Strided linear layer. - self.query_key_value = init_method( - torch.nn.Linear, - hidden_size, - 3 * self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - - self.dense = init_method( - torch.nn.Linear, - self.inner_hidden_size, - hidden_size, - bias=bias, - dtype=params_dtype, - ) - - @staticmethod - def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # [seq_len, batch, 3 * hidden_size] - mixed_raw_layer = self.query_key_value(hidden_states) - - # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] - new_tensor_shape = mixed_raw_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) - - # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] - (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) - - if self.position_encoding_2d: - q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) - k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) - cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - position_ids, block_position_ids = ( - position_ids[:, 0, :].transpose(0, 1).contiguous(), - position_ids[:, 1, :].transpose(0, 1).contiguous(), - ) - q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) - q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) - query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) - key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) - else: - position_ids = position_ids.transpose(0, 1) - cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) - # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] - query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) - - # [seq_len, batch, hidden_size] - context_layer, present, attention_probs = attention_fn( - self=self, - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - hidden_size_per_partition=self.hidden_size_per_partition, - 
layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache, - ) - - output = self.dense(context_layer) - - outputs = (output, present) - - if output_attentions: - outputs += (attention_probs,) - - return outputs # output, present, attention_probs - - -class GEGLU(torch.nn.Module): - def __init__(self): - super().__init__() - self.activation_fn = F.gelu - - def forward(self, x): - # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim - 1)) - return x1 * self.activation_fn(x2) - - -class GLU(torch.nn.Module): - def __init__( - self, - hidden_size, - inner_hidden_size=None, - layer_id=None, - bias=True, - activation_func=gelu, - params_dtype=torch.float, - empty_init=True, - ): - super(GLU, self).__init__() - if empty_init: - init_method = skip_init - else: - init_method = default_init - self.layer_id = layer_id - self.activation_func = activation_func - - # Project to 4h. - self.hidden_size = hidden_size - if inner_hidden_size is None: - inner_hidden_size = 4 * hidden_size - self.inner_hidden_size = inner_hidden_size - self.dense_h_to_4h = init_method( - torch.nn.Linear, - self.hidden_size, - self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - # Project back to h. - self.dense_4h_to_h = init_method( - torch.nn.Linear, - self.inner_hidden_size, - self.hidden_size, - bias=bias, - dtype=params_dtype, - ) - - def forward(self, hidden_states): - """ - hidden_states: [seq_len, batch, hidden_size] - """ - - # [seq_len, batch, inner_hidden_size] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel) - - output = self.dense_4h_to_h(intermediate_parallel) - - return output - - -class GLMBlock(torch.nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True, - ): - super(GLMBlock, self).__init__() - # Set output layer initialization if not provided. - - self.layer_id = layer_id - - # Layernorm on the input data. - self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.position_encoding_2d = position_encoding_2d - - # Self attention. - self.attention = SelfAttention( - hidden_size, - num_attention_heads, - layer_id, - hidden_size_per_attention_head=hidden_size_per_attention_head, - bias=use_bias, - params_dtype=params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init, - ) - - # Layernorm on the input data. - self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.num_layers = num_layers - - # GLU - self.mlp = GLU( - hidden_size, - inner_hidden_size=inner_hidden_size, - bias=use_bias, - layer_id=layer_id, - params_dtype=params_dtype, - empty_init=empty_init, - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # Layer norm at the begining of the transformer layer. - # [seq_len, batch, hidden_size] - attention_input = self.input_layernorm(hidden_states) - - # Self attention. 
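# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch only, not part of the original patch.
# A few lines below, GLMBlock combines each sub-layer output with a scaled
# residual, hidden = input * alpha + output, where alpha = (2 * num_layers) ** 0.5
# (a DeepNorm-style scaling). The stand-alone snippet below just shows that
# update with toy tensors; the shapes and num_layers are assumptions for the demo.
import torch

num_layers = 28
alpha = (2 * num_layers) ** 0.5                              # ~7.48 for a 28-layer model

attention_input = torch.randn(4, 2, 16)                      # [seq_len, batch, hidden_size]
attention_output = torch.randn(4, 2, 16)                     # stand-in for self.attention(...) output
hidden_states = attention_input * alpha + attention_output
print(hidden_states.shape)
# ---------------------------------------------------------------------------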
- attention_outputs = self.attention( - attention_input, - position_ids, - attention_mask=attention_mask, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - attention_output = attention_outputs[0] - - outputs = attention_outputs[1:] - - # Residual connection. - alpha = (2 * self.num_layers) ** 0.5 - hidden_states = attention_input * alpha + attention_output - - mlp_input = self.post_attention_layernorm(hidden_states) - - # MLP. - mlp_output = self.mlp(mlp_input) - - # Second residual connection. - output = mlp_input * alpha + mlp_output - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, device): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): - batch_size, seq_length = input_ids.shape - if use_gmasks is None: - use_gmasks = [False] * batch_size - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [ - torch.cat( - ( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1, - ) - ) - for context_length in context_lengths - ] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - if not use_gmasks[i]: - position_ids[i, context_length:] = mask_positions[i] - - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ChatGLMModel): - module.gradient_checkpointing = value - - -CHATGLM_6B_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. 
- Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CHATGLM_6B_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`ChatGLM6BTokenizer`]. - See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range `[0, config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert *input_ids* indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", - CHATGLM_6B_START_DOCSTRING, -) -class ChatGLMModel(ChatGLMPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in [Attention is - all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as a decoder the model needs to be initialized with the - `is_decoder` argument of the configuration set to `True`. - To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` - argument and `add_cross_attention` set to `True`; an - `encoder_hidden_states` is then expected as an input to the forward pass.
- """ - - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - # recording parameters - self.max_sequence_length = config.max_sequence_length - self.hidden_size = config.hidden_size - self.params_dtype = torch.half - self.num_attention_heads = config.num_attention_heads - self.vocab_size = config.vocab_size - self.num_layers = config.num_layers - self.layernorm_epsilon = config.layernorm_epsilon - self.inner_hidden_size = config.inner_hidden_size - self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads - self.position_encoding_2d = config.position_encoding_2d - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - - self.word_embeddings = init_method( - torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype - ) - self.gradient_checkpointing = False - - def get_layer(layer_id): - return GLMBlock( - self.hidden_size, - self.num_attention_heads, - self.layernorm_epsilon, - layer_id, - inner_hidden_size=self.inner_hidden_size, - hidden_size_per_attention_head=self.hidden_size_per_attention_head, - layernorm=LayerNorm, - use_bias=True, - params_dtype=self.params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init, - ) - - self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)]) - - # Final layer norm before output. - self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) - - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - # total_params = sum(p.numel() for p in self.parameters()) - # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) - # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings: torch.Tensor): - self.word_embeddings = new_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.num_attention_heads, - self.hidden_size // self.num_attention_heads, - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - # past_key_values = [(v[0], v[1]) for v in past_key_values] - return past_key_values - - @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: 
Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if past_key_values is None: - if self.pre_seq_len is not None: - past_key_values = self.get_prompt( - batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype - ) - else: - past_key_values = tuple([None] * len(self.layers)) - - if attention_mask is None: - attention_mask = self.get_masks(input_ids, device=input_ids.device) - - if position_ids is None: - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - position_ids = self.get_position_ids( - input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks - ) - - if self.pre_seq_len is not None and attention_mask is not None: - prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device - ) - prefix_attention_mask = (prefix_attention_mask < 0.5).bool() - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) - - # [seq_len, batch, hidden_size] - hidden_states = inputs_embeds.transpose(0, 1) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if attention_mask is None: - attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() - else: - attention_mask = attention_mask.to(hidden_states.device) - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_past = past_key_values[i] - - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - position_ids, - attention_mask, - torch.tensor(i), - layer_past, - use_cache, - output_attentions, - ) - else: - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = layer_ret[0] - - if use_cache: - presents = presents + (layer_ret[1],) - 
- if output_attentions: - all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) - - # Final layer norm. - hidden_states = self.final_layernorm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - - # self.hidden_size = config.hidden_size - # self.params_dtype = torch.half - # self.vocab_size = config.vocab_size - self.max_sequence_length = config.max_sequence_length - - self.position_encoding_2d = config.position_encoding_2d - - self.transformer = ChatGLMModel(config, empty_init=empty_init) - - self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half) - - self.config = config - - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if attention_mask is not None and attention_mask.dtype == torch.bool: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3 - ) - new_attention_mask = attention_mask[:, :, -1:].clone() - new_attention_mask[..., -1] = False - model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) - - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - batch_size, seq_length = input_ids.shape - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - # only last token for input_ids if past is not None - if past is not None or past_key_values is not None: - last_token = input_ids[:, -1].unsqueeze(-1) - if attention_mask is not None and attention_mask.dtype == 
torch.bool: - attention_mask = attention_mask[:, :, -1:] - else: - attention_mask = None - if position_ids is not None: - position_ids = position_ids[..., -1:] - else: - context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] - if self.position_encoding_2d: - position_ids = torch.tensor( - [ - [mask_position, seq_length - context_length] - for mask_position, context_length in zip(mask_positions, context_lengths) - ], - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(-1) - else: - position_ids = torch.tensor( - [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device - ).unsqueeze(-1) - - if past is None: - past = past_key_values - return { - "input_ids": last_token, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask, - } - else: - if attention_mask is not None and attention_mask.dtype != torch.bool: - logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") - attention_mask = None - if attention_mask is None: - attention_mask = self.get_masks(input_ids, device=input_ids.device) - if position_ids is None: - position_ids = self.get_position_ids( - input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks - ) - - return { - "input_ids": input_ids, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask, - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This 
function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - punkts = [ - [",", ","], - ["!", "!"], - [":", ":"], - [";", ";"], - ["\?", "?"], - ] - for item in punkts: - response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) - response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) - return response - - @torch.no_grad() - def chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - max_length: int = 2048, - num_beams=1, - do_sample=True, - top_p=0.7, - temperature=0.95, - logits_processor=None, - **kwargs, - ): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = { - "max_length": max_length, - "num_beams": num_beams, - "do_sample": do_sample, - "top_p": top_p, - "temperature": temperature, - "logits_processor": logits_processor, - **kwargs, - } - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = None, - max_length: int = 2048, - do_sample=True, - top_p=0.7, - temperature=0.95, - logits_processor=None, - **kwargs, - ): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = { - "max_length": max_length, - "do_sample": do_sample, - "top_p": top_p, - "temperature": temperature, - "logits_processor": logits_processor, - **kwargs, - } - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] - response = tokenizer.decode(outputs) - response = self.process_response(response) - new_history = history + [(query, response)] - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - 
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - yield input_ids - - def quantize(self, bits: int, empty_init=False, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) - return self diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py deleted file mode 100644 index 4ab0cdc8a..000000000 --- a/applications/Chat/coati/models/generation.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Any, Callable, Optional - -import torch -import torch.distributed as dist -from transformers import PreTrainedTokenizer - -from .base import Actor - -try: - from transformers.generation_logits_process import ( - LogitsProcessorList, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - ) -except ImportError: - from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper - - -def _prepare_logits_processor( - top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None -) -> LogitsProcessorList: - processor_list = LogitsProcessorList() - if temperature is not None and temperature != 1.0: - processor_list.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - processor_list.append(TopKLogitsWarper(top_k)) - if top_p is not None and top_p < 1.0: - 
processor_list.append(TopPLogitsWarper(top_p)) - return processor_list - - -def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: - if dist.is_initialized() and dist.get_world_size() > 1: - # consider DP - unfinished_sequences = unfinished_sequences.clone() - dist.all_reduce(unfinished_sequences) - return unfinished_sequences.max() == 0 - - -def _sample( - model: Actor, - input_ids: torch.Tensor, - max_length: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs, -) -> torch.Tensor: - if input_ids.size(1) >= max_length: - return input_ids - - logits_processor = _prepare_logits_processor(top_k, top_p, temperature) - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - - for _ in range(input_ids.size(1), max_length): - model_inputs = ( - prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} - ) - outputs = model(**model_inputs) - - # NOTE: this is correct only in left padding mode - next_token_logits = outputs["logits"][:, -1, :] - next_token_logits = logits_processor(input_ids, next_token_logits) - # sample - probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if update_model_kwargs_fn is not None: - model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - - # stop when each sentence is finished if early_stopping=True - if early_stopping and _is_sequence_finished(unfinished_sequences): - break - - return input_ids - - -@torch.no_grad() -def generate( - model: Actor, - input_ids: torch.Tensor, - tokenizer: PreTrainedTokenizer, - max_length: int, - num_beams: int = 1, - do_sample: bool = True, - early_stopping: bool = False, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs, -) -> torch.Tensor: - """Generate token sequence. The returned sequence is input_ids + generated_tokens. - - Args: - model (nn.Module): model - input_ids (torch.Tensor): input sequence - max_length (int): max length of the returned sequence - num_beams (int, optional): number of beams. Defaults to 1. - do_sample (bool, optional): whether to do sample. Defaults to True. - early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. 
- top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. - top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. - temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. - prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. - update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. - """ - assert tokenizer.padding_side == "left", "Current generation only supports left padding." - is_greedy_gen_mode = (num_beams == 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and do_sample is False - if is_greedy_gen_mode: - # run greedy search - raise NotImplementedError - elif is_sample_gen_mode: - # run sample - return _sample( - model, - input_ids, - max_length, - early_stopping=early_stopping, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - top_k=top_k, - top_p=top_p, - temperature=temperature, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs, - ) - elif is_beam_gen_mode: - raise NotImplementedError - else: - raise ValueError("Unsupported generation mode") diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py deleted file mode 100644 index 823cf4a75..000000000 --- a/applications/Chat/coati/models/gpt/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .gpt_actor import GPTActor -from .gpt_critic import GPTCritic -from .gpt_rm import GPTRM - -__all__ = ["GPTActor", "GPTCritic", "GPTRM"] diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py deleted file mode 100644 index a7e4b9bc3..000000000 --- a/applications/Chat/coati/models/gpt/gpt_actor.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel - -from ..base import Actor - - -class GPTActor(Actor): - """ - GPT Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (GPT2Config): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the LoRa layer. - lora_train_bias (str): Bias training strategy for the LoRa layer. 
- """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = "none", - **kwargs, - ) -> None: - if pretrained is not None: - model = GPT2LMHeadModel.from_pretrained(pretrained) - elif config is not None: - model = GPT2LMHeadModel(config) - else: - model = GPT2LMHeadModel(GPT2Config()) - if checkpoint: - model.gradient_checkpointing_enable() - super().__init__(model, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py deleted file mode 100644 index 22ab36dea..000000000 --- a/applications/Chat/coati/models/gpt/gpt_critic.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2Model - -from ..base import Critic - - -class GPTCritic(Critic): - """ - GPT Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (GPT2Config): Model config. - lora_rank (int): Rank of the LO-RA decomposition. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - **kwargs, - ) -> None: - if pretrained is not None: - model = GPT2Model.from_pretrained(pretrained) - elif config is not None: - model = GPT2Model(config) - else: - model = GPT2Model(GPT2Config()) - - value_head = nn.Linear(model.config.n_embd, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py deleted file mode 100644 index 8edfc4008..000000000 --- a/applications/Chat/coati/models/gpt/gpt_rm.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2Model - -from ..base import RewardModel - - -class GPTRM(RewardModel): - """ - GPT Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (GPT2Config): Model config. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = GPT2Model.from_pretrained(pretrained) - elif config is not None: - model = GPT2Model(config) - else: - model = GPT2Model(GPT2Config()) - - value_head = nn.Linear(model.config.n_embd, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py deleted file mode 100644 index c87d73253..000000000 --- a/applications/Chat/coati/models/llama/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .llama_actor import LlamaActor -from .llama_critic import LlamaCritic -from .llama_rm import LlamaRM - -__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"] diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py deleted file mode 100644 index f1d940683..000000000 --- a/applications/Chat/coati/models/llama/llama_actor.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -from transformers import LlamaConfig, LlamaForCausalLM - -from ..base import Actor - - -class LlamaActor(Actor): - """ - Llama Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (LlamaConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = LlamaForCausalLM.from_pretrained(pretrained) - elif config is not None: - model = LlamaForCausalLM(config) - else: - model = LlamaForCausalLM(LlamaConfig()) - - if checkpoint: - model.gradient_checkpointing_enable() - - super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py deleted file mode 100644 index 000dce17c..000000000 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import LlamaConfig, LlamaModel - -from ..base import Critic - - -class LlamaCritic(Critic): - """ - Llama Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (LlamaConfig): Model config. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - **kwargs, - ) -> None: - if pretrained is not None: - model = LlamaModel.from_pretrained(pretrained) - elif config is not None: - model = LlamaModel(config) - else: - model = LlamaModel(LlamaConfig()) - - value_head = nn.Linear(model.config.hidden_size, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py deleted file mode 100644 index 43bc9e638..000000000 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import LlamaConfig, LlamaModel - -from ..base import RewardModel - - -class LlamaRM(RewardModel): - """ - Llama Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (LlamaConfig): Model config. - lora_rank (int): LoRA rank. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = LlamaModel.from_pretrained(pretrained) - elif config is not None: - model = LlamaModel(config) - else: - model = LlamaModel(LlamaConfig()) - - value_head = nn.Linear(model.config.hidden_size, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) - - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py deleted file mode 100644 index 687bd0f7b..000000000 --- a/applications/Chat/coati/models/loss.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - -from .utils import masked_mean - - -class GPTLMLoss(nn.Module): - """ - GPT Language Model Loss - """ - - def __init__(self): - super().__init__() - # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py - self.loss = nn.CrossEntropyLoss() - - def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - -class PolicyLoss(nn.Module): - """ - Policy Loss for PPO - """ - - def __init__(self, clip_eps: float = 0.2) -> None: - super().__init__() - self.clip_eps = clip_eps - - def forward( - self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - action_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - ratio = (log_probs - old_log_probs).exp() - surr1 = ratio * advantages - surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages - loss = -torch.min(surr1, surr2) - if action_mask is not None: - loss = masked_mean(loss, action_mask) - loss = loss.mean() - return loss - - -class ValueLoss(nn.Module): - """ - Value Loss for PPO - """ - - def __init__(self, clip_eps: float = 0.4) -> None: - super().__init__() - self.clip_eps = clip_eps - - def forward( - self, - values: torch.Tensor, - old_values: torch.Tensor, - reward: torch.Tensor, - action_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - values_clipped = old_values + (values - 
old_values).clamp(-self.clip_eps, self.clip_eps) - surr1 = (values_clipped - reward) ** 2 - surr2 = (values - reward) ** 2 - loss = torch.max(surr1, surr2) - loss = loss.mean() - return 0.5 * loss - - -class LogSigLoss(nn.Module): - """ - Pairwise Loss for Reward Model - Details: https://arxiv.org/abs/2203.02155 - """ - - def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: - probs = torch.sigmoid(chosen_reward - reject_reward) - log_probs = torch.log(probs) - loss = -log_probs.mean() - return loss - - -class LogExpLoss(nn.Module): - """ - Pairwise Loss for Reward Model - Details: https://arxiv.org/abs/2204.05862 - """ - - def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: - loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() - return loss diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py deleted file mode 100644 index e37d6e45c..000000000 --- a/applications/Chat/coati/models/opt/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .opt_actor import OPTActor -from .opt_critic import OPTCritic -from .opt_rm import OPTRM - -__all__ = ["OPTActor", "OPTCritic", "OPTRM"] diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py deleted file mode 100644 index cd8908e13..000000000 --- a/applications/Chat/coati/models/opt/opt_actor.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -from transformers.models.opt.configuration_opt import OPTConfig -from transformers.models.opt.modeling_opt import OPTForCausalLM - -from ..base import Actor - - -class OPTActor(Actor): - """ - OPT Actor model. - - Args: - pretrained (str): Pretrained model name or path. - config (OPTConfig): Model config. - checkpoint (bool): Enable gradient checkpointing. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = OPTForCausalLM.from_pretrained(pretrained) - elif config is not None: - model = OPTForCausalLM(config) - else: - model = OPTForCausalLM(OPTConfig()) - if checkpoint: - model.gradient_checkpointing_enable() - super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py deleted file mode 100644 index f37d28812..000000000 --- a/applications/Chat/coati/models/opt/opt_critic.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers.models.opt.configuration_opt import OPTConfig -from transformers.models.opt.modeling_opt import OPTModel - -from ..base import Critic - - -class OPTCritic(Critic): - """ - OPT Critic model. - - Args: - pretrained (str): Pretrained model name or path. - config (OPTConfig): Model config. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. 
- """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - **kwargs, - ) -> None: - if pretrained is not None: - model = OPTModel.from_pretrained(pretrained) - elif config is not None: - model = OPTModel(config) - else: - model = OPTModel(OPTConfig()) - - value_head = nn.Linear(model.config.word_embed_proj_dim, 1) - super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py deleted file mode 100644 index 893708344..000000000 --- a/applications/Chat/coati/models/opt/opt_rm.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from transformers import OPTConfig, OPTModel - -from ..base import RewardModel - - -class OPTRM(RewardModel): - """ - OPT Reward model. - - Args: - pretrained (str): Pretrained model name or path. - config (OPTConfig): Model config. - lora_rank (int): Rank of the low-rank approximation. - lora_train_bias (str): LoRA bias training mode. - """ - - def __init__( - self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = "none", - ) -> None: - if pretrained is not None: - model = OPTModel.from_pretrained(pretrained) - elif config is not None: - model = OPTModel(config) - else: - model = OPTModel(OPTConfig()) - - value_head = nn.Linear(model.config.word_embed_proj_dim, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py deleted file mode 100644 index 1aaef1662..000000000 --- a/applications/Chat/coati/models/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Optional, Union - -import torch -import torch.nn.functional as F - - -def _compute_approx_kl( - log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None -) -> torch.Tensor: - """ - Compute the approximate KL divergence between two distributions. - Schulman blog: http://joschu.net/blog/kl-approx.html - - Args: - log_probs: Log probabilities of the new distribution. - log_probs_base: Log probabilities of the base distribution. - action_mask: Mask for actions. - """ - - log_ratio = log_probs_base - log_probs - approx_kl = (log_ratio.exp() - 1) - log_ratio - if action_mask is not None: - approx_kl = masked_mean(approx_kl, action_mask, dim=1) - return approx_kl - approx_kl = approx_kl.mean(dim=1) - return approx_kl - - -def compute_reward( - r: Union[torch.Tensor, float], - kl_coef: float, - log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - if kl_coef <= 0.0: - return r - kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) - reward = r - kl_coef * kl - return reward - - -def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - log_probs = F.log_softmax(logits, dim=-1) - log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) - return log_probs_labels.squeeze(-1) - - -def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: - """Calculate action log probs. - - Args: - output (torch.Tensor): Output tensor of Actor.forward.logits. 
- sequences (torch.LongTensor): Input sequences. - num_actions (int): Number of actions. - - Returns: - torch.Tensor: Action log probs. - """ - log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) - return log_probs[:, -num_actions:] - - -def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: - tensor = tensor * mask - tensor = tensor.sum(dim=dim) - mask_sum = mask.sum(dim=dim) - mean = tensor / (mask_sum + 1e-8) - return mean diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py deleted file mode 100644 index 4be5d27f9..000000000 --- a/applications/Chat/coati/trainer/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .base import OnPolicyTrainer, SLTrainer -from .ppo import PPOTrainer -from .rm import RewardModelTrainer -from .sft import SFTTrainer - -__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"] diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py deleted file mode 100644 index 29c8c4f00..000000000 --- a/applications/Chat/coati/trainer/callbacks/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import Callback -from .performance_evaluator import PerformanceEvaluator -from .save_checkpoint import SaveCheckpoint - -__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"] diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py deleted file mode 100644 index 0d70b6c53..000000000 --- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py +++ /dev/null @@ -1,76 +0,0 @@ -import os - -import torch.distributed as dist -from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy -from coati.trainer.utils import is_rank_0 -from torch import nn -from torch.optim import Optimizer - -from .base import Callback - - -class SaveCheckpoint(Callback): - """ - The callback for saving checkpoint for coati. - - Only support saving actor and critic model. - A typical architecture of the saved checkpoint would be: - - checkpoint - - episode_x - - actor.pt - - actor-optim-rank-0.pt - - actor-optim-rank-1.pt - - critic.pt - - critic-optim-rank-0.pt - - critic-optim-rank-1.pt - - ... 
- - Args: - path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint` - interval(int): the interval episode of saving checkpoint - strategy(Strategy): the strategy used to train - actor(nn.Module): the actor model - critic(nn.Module): the critic model - actor_optim(Optimizer): the optimizer of actor - critic_optim(Optimizer): the optimizer of critic - - """ - - def __init__( - self, - path: str, - interval: int, - strategy: Strategy, - actor: nn.Module = None, - critic: nn.Module = None, - actor_optim: Optimizer = None, - critic_optim: Optimizer = None, - ) -> None: - super().__init__() - self.path = os.path.join(path, "checkpoint") - self.interval = interval - self.strategy = strategy - self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]} - - def on_episode_end(self, episode: int) -> None: - if (episode + 1) % self.interval != 0: - return - base_path = os.path.join(self.path, f"episode_{episode}") - if not os.path.exists(base_path): - os.makedirs(base_path) - - for model in self.model_dict.keys(): - # save model - if self.model_dict[model][0] is None: - # saving only optimizer states is meaningless, so it would be skipped - continue - model_path = os.path.join(base_path, f"{model}.pt") - self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) - - # save optimizer - if self.model_dict[model][1] is None: - continue - only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)) - rank = 0 if is_rank_0() else dist.get_rank() - optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt") - self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py deleted file mode 100644 index 330e4e0e3..000000000 --- a/applications/Chat/coati/trainer/ppo.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import Dict, List, Optional - -from coati.experience_buffer import NaiveExperienceBuffer -from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models.base import Actor, Critic, RewardModel, get_base_model -from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss -from coati.models.utils import calc_action_log_probs -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler -from tqdm import tqdm -from transformers import PreTrainedTokenizerBase - -from colossalai.accelerator import get_accelerator - -from .base import OnPolicyTrainer -from .callbacks import Callback -from .strategies import GeminiStrategy, Strategy -from .utils import CycledDataLoader, is_rank_0, to_device - - -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: - unwrapped_model = strategy.unwrap_model(actor) - hf_model = get_base_model(unwrapped_model) - new_kwargs = {**generate_kwargs} - # use huggingface models method directly - if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): - new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation - - if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"): - new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation - - return new_kwargs - - -class PPOTrainer(OnPolicyTrainer): - """ - Trainer for PPO algorithm. 
- - Args: - strategy (Strategy): the strategy to use for training - actor (Actor): the actor model in the PPO algorithm - critic (Critic): the critic model in the PPO algorithm - reward_model (RewardModel): the reward model in the RLHF algorithm, used to compute rewards for generated sentences - initial_model (Actor): the initial model in the RLHF algorithm, used to generate reference logits that limit the update of the actor - actor_optim (Optimizer): the optimizer to use for the actor model - critic_optim (Optimizer): the optimizer to use for the critic model - tokenizer (PreTrainedTokenizerBase): the tokenizer passed to the experience maker - kl_coef (float, defaults to 0.1): the coefficient of the KL divergence loss - train_batch_size (int, defaults to 8): the batch size to use for training - buffer_limit (int, defaults to 0): the max_size limitation of the buffer - buffer_cpu_offload (bool, defaults to True): whether to offload the buffer to CPU - eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss - vf_coef (float, defaults to 1.0): the coefficient of the value loss - ptx_coef (float, defaults to 0.9): the coefficient of the ptx loss - value_clip (float, defaults to 0.4): the clip coefficient of the value loss - sample_buffer (bool, defaults to False): whether to sample from the buffer - dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader - offload_inference_models (bool, defaults to True): whether to offload inference models to CPU during the training process - callbacks (List[Callback], defaults to []): the callbacks to call during the training process - generate_kwargs (dict, optional): the kwargs to use during model generation - """ - - def __init__( - self, - strategy: Strategy, - actor: Actor, - critic: Critic, - reward_model: RewardModel, - initial_model: Actor, - actor_optim: Optimizer, - critic_optim: Optimizer, - tokenizer: PreTrainedTokenizerBase, - kl_coef: float = 0.1, - ptx_coef: float = 0.9, - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - vf_coef: float = 1.0, - value_clip: float = 0.4, - sample_buffer: bool = False, - dataloader_pin_memory: bool = True, - offload_inference_models: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs, - ) -> None: - if isinstance(strategy, GeminiStrategy): - assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" - - data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) - super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) - - self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) - self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef) - - self.actor = actor - self.critic = critic - self.tokenizer = tokenizer - - self.actor_loss_fn = PolicyLoss(eps_clip) - self.critic_loss_fn = ValueLoss(value_clip) - self.vf_coef = vf_coef - self.ptx_loss_fn = GPTLMLoss() - self.ptx_coef = ptx_coef - self.actor_optim = actor_optim - self.critic_optim = critic_optim - - self.offload_inference_models = offload_inference_models - self.device = get_accelerator().get_current_device() - - def _before_fit( - self, - prompt_dataloader: DataLoader, - pretrain_dataloader: DataLoader, - log_dir: Optional[str] = None, - use_wandb: bool = False, - ): - """ - Args: - prompt_dataloader (DataLoader): the dataloader to use for prompt data - pretrain_dataloader (DataLoader): the dataloader to use for pretrain data - """ - self.prompt_dataloader = CycledDataLoader(prompt_dataloader) - self.pretrain_dataloader =
CycledDataLoader(pretrain_dataloader) - - self.writer = None - if use_wandb and is_rank_0(): - assert log_dir is not None, "log_dir must be provided when use_wandb is True" - import wandb - - wandb.init(project="Coati-ppo", sync_tensorboard=True) - if log_dir is not None and is_rank_0(): - import os - import time - - from torch.utils.tensorboard import SummaryWriter - - log_dir = os.path.join(log_dir, "ppo") - log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) - self.writer = SummaryWriter(log_dir=log_dir) - - def _make_experience(self, collect_step: int) -> Experience: - prompts = self.prompt_dataloader.next() - if self.offload_inference_models: - # TODO(ver217): this may be controlled by strategy if they are prepared by strategy - self.experience_maker.initial_model.to(self.device) - self.experience_maker.reward_model.to(self.device) - assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"' - return self.experience_maker.make_experience(**prompts, **self.generate_kwargs) - - def _training_step(self, experience: Experience): - self.actor.train() - self.critic.train() - # policy loss - num_actions = experience.action_log_probs.size(1) - actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"] - action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) - actor_loss = self.actor_loss_fn( - action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask - ) - actor_loss = (1 - self.ptx_coef) * actor_loss - self.strategy.backward(actor_loss, self.actor, self.actor_optim) - - # ptx loss - if self.ptx_coef != 0: - batch = self.pretrain_dataloader.next() - batch = to_device(batch, self.device) - ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"] - ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"]) - self.strategy.backward(ptx_loss, self.actor, self.actor_optim) - - self.strategy.optimizer_step(self.actor_optim) - self.actor_optim.zero_grad() - - # value loss - values = self.critic(experience.sequences, attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, experience.values, experience.reward) - critic_loss = critic_loss * self.vf_coef - self.strategy.backward(critic_loss, self.critic, self.critic_optim) - self.strategy.optimizer_step(self.critic_optim) - self.critic_optim.zero_grad() - - def _learn(self, update_step: int): - if self.offload_inference_models: - self.experience_maker.initial_model.to("cpu") - self.experience_maker.reward_model.to("cpu") - - # buffer may be empty at first, we should rebuild at each training - if self.sample_buffer: - experience = self.data_buffer.sample() - self._on_learn_batch_start() - experience.to_device(self.device) - self._training_step(experience) - self._on_learn_batch_end(experience) - else: - if isinstance(self.dataloader.sampler, DistributedSampler): - self.dataloader.sampler.set_epoch(update_step) - pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) - for experience in pbar: - self._on_learn_batch_start() - experience.to_device(self.device) - self._training_step(experience) - self._on_learn_batch_end(experience) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py deleted file mode 100644 index d7f8c21a5..000000000 --- a/applications/Chat/coati/trainer/rm.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import Callable, Optional - 
-import torch -import tqdm -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader - -from .base import SLTrainer -from .strategies import Strategy -from .utils import is_rank_0 - - -class RewardModelTrainer(SLTrainer): - """ - Trainer to use while training reward model. - - Args: - model (torch.nn.Module): the model to train - strategy (Strategy): the strategy to use for training - optim (Optimizer): the optimizer to use for training - lr_scheduler (_LRScheduler): the lr scheduler to use for training - loss_fn (callable): the loss function to use for training - max_epochs (int, defaults to 2): the number of epochs to train - """ - - def __init__( - self, - model, - strategy: Strategy, - optim: Optimizer, - lr_scheduler: _LRScheduler, - loss_fn: Callable, - max_epochs: int = 1, - ) -> None: - super().__init__(strategy, max_epochs, model, optim) - - self.loss_fn = loss_fn - self.scheduler = lr_scheduler - - self.num_train_step = 0 - - def _eval(self, epoch): - if self.eval_dataloader is not None: - self.model.eval() - dist, num_correct, num_samples = 0, 0, 0 - with torch.no_grad(): - for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - num_samples += chosen_ids.size(0) - num_correct += (chosen_reward > reject_reward).sum().item() - dist += (chosen_reward - reject_reward).mean().item() - self.dist = dist / len(self.eval_dataloader) - self.acc = num_correct / num_samples - - if self.writer: - self.writer.add_scalar("eval/dist", self.dist, epoch) - self.writer.add_scalar("eval/acc", self.acc, epoch) - - def _train(self, epoch): - self.model.train() - step_bar = tqdm.trange( - len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0() - ) - for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: - chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device()) - c_mask = c_mask.squeeze(1).to(torch.cuda.current_device()) - reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device()) - r_mask = r_mask.squeeze(1).to(torch.cuda.current_device()) - chosen_reward = self.model(chosen_ids, attention_mask=c_mask) - reject_reward = self.model(reject_ids, attention_mask=r_mask) - loss = self.loss_fn(chosen_reward, reject_reward) - self.strategy.backward(loss, self.model, self.optimizer) - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - if self.writer: - self.writer.add_scalar("train/loss", loss.item(), self.num_train_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) - self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step) - self.writer.add_scalar( - "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step - ) - self.num_train_step += 1 - if self.num_train_step % 100 == 0: - self.scheduler.step() - step_bar.update() - step_bar.close() - - def _before_fit( - self, - train_dataloader: DataLoader, - eval_dataloader: DataLoader, - log_dir: Optional[str] = None, - use_wandb: bool = False, - ): - """ - Args: - train_dataloader 
(DataLoader): the dataloader to use for training - eval_dataloader (DataLoader): the dataloader to use for evaluation - """ - self.train_dataloader = train_dataloader - self.eval_dataloader = eval_dataloader - - self.writer = None - if use_wandb and is_rank_0(): - assert log_dir is not None, "log_dir must be provided when use_wandb is True" - import wandb - - wandb.init(project="Coati-rm", sync_tensorboard=True) - if log_dir is not None and is_rank_0(): - import os - import time - - from torch.utils.tensorboard import SummaryWriter - - log_dir = os.path.join(log_dir, "rm") - log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) - self.writer = SummaryWriter(log_dir=log_dir) diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py deleted file mode 100644 index 7d0eeec89..000000000 --- a/applications/Chat/coati/trainer/sft.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed as dist -import tqdm -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader - -from colossalai.logging import DistributedLogger - -from .base import SLTrainer -from .strategies import GeminiStrategy, Strategy -from .utils import is_rank_0, to_device - - -class SFTTrainer(SLTrainer): - """ - Trainer to use while training reward model. - - Args: - model (torch.nn.Module): the model to train - strategy (Strategy): the strategy to use for training - optim(Optimizer): the optimizer to use for training - lr_scheduler(_LRScheduler): the lr scheduler to use for training - max_epochs (int, defaults to 2): the number of epochs to train - accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients - """ - - def __init__( - self, - model, - strategy: Strategy, - optim: Optimizer, - lr_scheduler: _LRScheduler, - max_epochs: int = 2, - accumulation_steps: int = 8, - ) -> None: - if accumulation_steps > 1: - assert not isinstance( - strategy, GeminiStrategy - ), "Accumulation steps are not supported in stage 3 of ColossalAI" - - super().__init__(strategy, max_epochs, model, optim) - - self.accumulation_steps = accumulation_steps - self.scheduler = lr_scheduler - - self.num_train_step = 0 - self.num_eval_step = 0 - - def _train(self, epoch: int): - self.model.train() - step_bar = tqdm.trange( - len(self.train_dataloader) // self.accumulation_steps, - desc=f"Epoch {epoch + 1}/{self.max_epochs}", - disable=not is_rank_0(), - ) - for i, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) - loss = outputs.loss / self.accumulation_steps - self.total_loss += loss.item() - self.strategy.backward(loss, self.model, self.optimizer) - # gradient accumulation - if (i + 1) % self.accumulation_steps == 0: - self.strategy.optimizer_step(self.optimizer) - self.optimizer.zero_grad() - self.scheduler.step() - if self.writer: - self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step) - self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) - self.num_train_step += 1 - self.total_loss = 0 - step_bar.update() - step_bar.close() - - def _eval(self, epoch: int): - if self.eval_dataloader is not None: - self.model.eval() - with torch.no_grad(): - loss_sum, num_seen = 0, 0 - for batch in self.eval_dataloader: - batch = to_device(batch, 
torch.cuda.current_device()) - outputs = self.model( - batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] - ) - loss_sum += outputs.loss.item() - num_seen += batch["input_ids"].size(0) - loss_mean = loss_sum / num_seen - if dist.get_rank() == 0: - self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") - if self.writer: - self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step) - self.num_eval_step += 1 - - def _before_fit( - self, - train_dataloader: DataLoader, - eval_dataloader: Optional[DataLoader] = None, - logger: Optional[DistributedLogger] = None, - log_dir: Optional[str] = None, - use_wandb: bool = False, - ): - """ - Args: - train_dataloader: the dataloader to use for training - eval_dataloader: the dataloader to use for evaluation - """ - self.train_dataloader = train_dataloader - self.eval_dataloader = eval_dataloader - - self.logger = logger - self.writer = None - if use_wandb and is_rank_0(): - assert log_dir is not None, "log_dir must be provided when use_wandb is True" - import wandb - - wandb.init(project="Coati-sft", sync_tensorboard=True) - if log_dir is not None and is_rank_0(): - import os - import time - - from torch.utils.tensorboard import SummaryWriter - - log_dir = os.path.join(log_dir, "sft") - log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) - self.writer = SummaryWriter(log_dir=log_dir) - - self.total_loss = 0 diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py deleted file mode 100644 index 521dcb585..000000000 --- a/applications/Chat/coati/trainer/strategies/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import Strategy -from .colossalai import GeminiStrategy, LowLevelZeroStrategy -from .ddp import DDPStrategy - -__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"] diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py deleted file mode 100644 index a78716216..000000000 --- a/applications/Chat/coati/trainer/strategies/base.py +++ /dev/null @@ -1,137 +0,0 @@ -from abc import ABC, abstractmethod -from contextlib import nullcontext -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from coati.experience_buffer import ExperienceBuffer -from torch.optim import Optimizer -from torch.utils.data import DataLoader -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from colossalai.booster import Booster -from colossalai.booster.plugin import Plugin - -from .sampler import DistributedSampler - -_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict] - - -class Strategy(ABC): - """ - Base class for training strategies. 
- """ - - def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None: - super().__init__() - # NOTE: dist must be initialized before Booster - self.setup_distributed() - self.plugin = plugin_initializer() - self.booster = Booster(plugin=self.plugin) - self._post_init() - - @abstractmethod - def _post_init(self) -> None: - pass - - def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None: - self.booster.backward(loss, optimizer) - - def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None: - optimizer.step() - - @abstractmethod - def setup_distributed(self) -> None: - pass - - @abstractmethod - def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: - pass - - def model_init_context(self): - return nullcontext() - - def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]: - """Prepare [model | (model, optimizer) | Dict] based on each strategy. - NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments. - - Example:: - >>> # e.g., include lr_scheduler - >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler)) - >>> # when fine-tuning actor and critic - >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) - >>> # or when training reward model - >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim)) - >>> # or just inference - >>> actor, critic = strategy.prepare(actor, critic) - - Returns: - Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order. - """ - - rets = [] - for arg in boost_args: - if isinstance(arg, nn.Module): - model, *_ = self.booster.boost(arg) - rets.append(model) - elif isinstance(arg, tuple): - try: - model, optimizer = arg - except ValueError: - raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"') - model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer) - rets.append((model, optimizer)) - elif isinstance(arg, Dict): - model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) - boost_result = dict( - model=model, - optimizer=optimizer, - criterion=criterion, - dataloader=dataloader, - lr_scheduler=lr_scheduler, - ) - # remove None values - boost_result = {key: value for key, value in boost_result.items() if value is not None} - rets.append(boost_result) - else: - raise RuntimeError(f"Type {type(arg)} is not supported") - - return rets[0] if len(rets) == 1 else rets - - @staticmethod - def unwrap_model(model: nn.Module) -> nn.Module: - """Get the unwrapped model from a wrapped model made by Strategy.prepare. 
- - Args: - model (nn.Module): the model to unwrap - - Returns: - nn.Module: the original model - """ - return model - - def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None: - self.booster.save_model(model, path, shard=shard, **kwargs) - - def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None: - self.booster.load_model(model, path, strict) - - def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None: - self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs) - - def load_optimizer(self, optimizer: Optimizer, path: str) -> None: - self.booster.load_optimizer(optimizer, path) - - def setup_sampler(self, dataset) -> DistributedSampler: - # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. - return DistributedSampler(dataset, 1, 0) - - @abstractmethod - def save_pretrained( - self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None - ) -> None: - pass - - @abstractmethod - def get_model_state_dict_shard(self, model: nn.Module, **config): - pass diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py deleted file mode 100644 index 95f016786..000000000 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ /dev/null @@ -1,209 +0,0 @@ -import warnings -from typing import Optional - -import torch.nn as nn - -import colossalai -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.zero.gemini.gemini_ddp import GeminiDDP - -from .ddp import DDPStrategy - - -class LowLevelZeroStrategy(DDPStrategy): - """ - The strategy for training with ColossalAI. - - Args: - stage(int): The stage to use in ZeRO. Choose in (1, 2) - precision(str): The precision to use. Choose in ('fp32', 'fp16'). - seed(int): The seed for the random number generator. - placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') - If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, - If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. - reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2. - overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2. - initial_scale(float): The initial scale for the optimizer. - growth_factor(float): The growth factor for the optimizer. - backoff_factor(float): The backoff factor for the optimizer. - growth_interval(int): The growth interval for the optimizer. - hysteresis(int): The hysteresis for the optimizer. - min_scale(float): The minimum scale for the optimizer. - max_scale(float): The maximum scale for the optimizer. - max_norm(float): The maximum norm for the optimizer. - norm_type(float): The norm type for the optimizer. 
- - """ - - def __init__( - self, - stage: int = 2, - precision: str = "fp16", - seed: int = 42, - placement_policy: str = "cuda", - reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 - overlap_communication: bool = True, # only for stage 1&2 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - ) -> None: - assert stage in (1, 2), f'Unsupported stage "{stage}"' - assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' - assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' - - plugin_initializer = lambda: LowLevelZeroPlugin( - stage=stage, - precision=precision, - reduce_bucket_size_in_m=reduce_bucket_size, - overlap_communication=overlap_communication, - cpu_offload=(placement_policy == "cpu"), - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type, - ) - - super().__init__(seed, plugin_initializer) - - def _post_init(self) -> None: - assert isinstance( - self.plugin, LowLevelZeroPlugin - ), f"{type(self).__name__}'s plugin is not initialized properly." - - def setup_distributed(self) -> None: - colossalai.launch_from_torch({}, seed=self.seed) - - def unwrap_model(self, model: nn.Module) -> nn.Module: - assert isinstance(model, LowLevelZeroModel) - return model.module - - def get_model_state_dict_shard(self, model: nn.Module, **config): - assert isinstance(model, LowLevelZeroModel) - yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False) - - -class GeminiStrategy(DDPStrategy): - """ - The strategy for training with ColossalAI. - - Args: - seed(int): The seed for the random number generator. - shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3. - This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future. - placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda') - If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, - If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest. - pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3. - force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3. - search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3. - hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3. - min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3. - gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3. - initial_scale(float): The initial scale for the optimizer. - growth_factor(float): The growth factor for the optimizer. - backoff_factor(float): The backoff factor for the optimizer. - growth_interval(int): The growth interval for the optimizer. - hysteresis(int): The hysteresis for the optimizer. - min_scale(float): The minimum scale for the optimizer. - max_scale(float): The maximum scale for the optimizer. - max_norm(float): The maximum norm for the optimizer. - norm_type(float): The norm type for the optimizer. 
- - """ - - def __init__( - self, - seed: int = 42, - shard_init: bool = False, # only for stage 3 - placement_policy: str = "auto", - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - pin_memory: bool = True, # only for stage 3 - force_outputs_fp32: bool = False, # only for stage 3 - search_range_m: int = 32, # only for stage 3 - hidden_dim: Optional[int] = None, # only for stage 3 - min_chunk_size_m: float = 32, # only for stage 3 - gpu_margin_mem_ratio: float = 0.0, # only for stage 3 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - ) -> None: - # TODO(ver217): support shard_init when using from_pretrained() - if shard_init: - warnings.warn( - f"Shard init is not supported model.from_pretrained() yet. " - "Please load weights after strategy.prepare()" - ) - self.shard_init = shard_init - - warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") - - # colossalai has changed api for get_current_device in 0.3.4 version or newer - try: - from colossalai.accelerator import get_accelerator - - chunk_init_device = get_accelerator().get_current_device() - except: - from colossalai.utils import get_current_device - - chunk_init_device = get_current_device() - - # NOTE: dist should be initialized before calling get_current_device() - plugin_initializer = lambda: GeminiPlugin( - chunk_init_device=chunk_init_device, - placement_policy=placement_policy, - shard_param_frac=shard_param_frac, - offload_optim_frac=offload_optim_frac, - offload_param_frac=offload_param_frac, - precision="fp16", - pin_memory=pin_memory, - force_outputs_fp32=force_outputs_fp32, - strict_ddp_mode=shard_init, - search_range_m=search_range_m, - hidden_dim=hidden_dim, - min_chunk_size_m=min_chunk_size_m, - gpu_margin_mem_ratio=gpu_margin_mem_ratio, - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type, - ) - - super().__init__(seed, plugin_initializer) - - def _post_init(self) -> None: - assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly." 
- - def setup_distributed(self) -> None: - colossalai.launch_from_torch({}, seed=self.seed) - - def model_init_context(self): - return super().model_init_context() - - def unwrap_model(self, model: nn.Module) -> nn.Module: - assert isinstance(model, GeminiDDP) - return model.module diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py deleted file mode 100644 index f2a44aeb0..000000000 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import random -from collections import OrderedDict -from typing import Callable, Optional - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from coati.experience_buffer import ExperienceBuffer -from coati.models import Actor, Critic, RewardModel -from torch.utils.data import DataLoader -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils_base import PreTrainedTokenizerBase - -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel - -from .base import Strategy -from .sampler import DistributedSampler - - -# TODO Move this to a util.py (Moving to ray.util introduces ringed import) -def get_grad_required_state_dict(model: nn.Module): - state_dict = OrderedDict() - for name, parameter in model.named_parameters(): - if parameter.requires_grad: - state_dict[name] = parameter.detach() - return state_dict - - -class DDPStrategy(Strategy): - """ - Strategy for distributed training using torch.distributed. - """ - - def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None: - self.seed = seed - super().__init__(plugin_initializer) - - def _try_init_dist(self, force: bool = False) -> None: - try: - rank = int(os.environ["RANK"]) - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - host = os.environ["MASTER_ADDR"] - port = int(os.environ["MASTER_PORT"]) - dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank) - torch.cuda.set_device(local_rank) - except KeyError as e: - if force: - raise RuntimeError( - f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" - ) - except Exception as e: - if force: - raise e - - def _post_init(self) -> None: - assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly." - - def setup_distributed(self) -> None: - self._try_init_dist(force=True) - self.set_seed(self.seed) - - def set_seed(self, seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: - return self.plugin.prepare_dataloader( - data_buffer, - batch_size=data_buffer.sample_batch_size, - shuffle=True, - drop_last=True, - pin_memory=pin_memory, - collate_fn=data_buffer.collate_fn, - ) - - def setup_sampler(self, dataset) -> DistributedSampler: - # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. - return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank()) - - def unwrap_model(self, model: nn.Module) -> nn.Module: - assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel." 
- return model.unwrap() - - def save_pretrained( - self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None - ) -> None: - if dist.get_rank() == 0: - unwrapped_model = self.unwrap_model(model) - assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) - pretrained_model = unwrapped_model.model - assert isinstance(pretrained_model, PreTrainedModel) - # HACK: only use hf save_pretrained to save config - pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None) - if tokenizer is not None: - tokenizer.save_pretrained(path) - - model_path = os.path.join(path, "pytorch_model.bin") - self.save_model(model, model_path, shard=shard) - def _replace_keys(model_path: str, replace_fn: Callable): - state_dict = torch.load(model_path, map_location="cpu") - state_dict = {replace_fn(k): v for k, v in state_dict.items()} - torch.save(state_dict, model_path) - # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin - # HACK: rename keys of pytorch_model.bin - if dist.get_rank() == 0: - _replace_keys(model_path, lambda k: k.replace("model.", "", 1)) - - - def get_model_state_dict_shard(self, model: nn.Module, **config): - # TODO: implement sharding on naive strategy - model = self.unwrap_model(model) - if "requires_grad_only" in config and config["requires_grad_only"] == True: - state_dict = get_grad_required_state_dict(model) - else: - state_dict = model.state_dict() - - if "shard_size" in config: - shard_size = config["shard_size"] - accumulate_size = 0 - state_dict_shard = OrderedDict() - for name, param in state_dict.items(): - state_dict_shard[name] = param - accumulate_size += param.numel() * param.element_size() - if accumulate_size >= shard_size: - accumulate_size = 0 - yield state_dict_shard - state_dict_shard = OrderedDict() - if accumulate_size > 0: - yield state_dict_shard - else: - yield state_dict diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py deleted file mode 100644 index 6e811bef1..000000000 --- a/applications/Chat/coati/trainer/strategies/sampler.py +++ /dev/null @@ -1,31 +0,0 @@ -import math - -import numpy as np - - -class DistributedSampler: - def __init__(self, dataset, num_replicas: int, rank: int) -> None: - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - - if len(self.dataset) % self.num_replicas != 0: - self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] - ) - else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) - - self.total_size = self.num_samples * self.num_replicas - - indices = list(range(len(self.dataset))) - indices = indices[: self.total_size] - assert len(indices) == self.total_size - # subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - self.indices = indices - - def sample(self, batch_size: int) -> list: - sampled_indices = np.random.choice(self.indices, batch_size, replace=False) - return [self.dataset[idx] for idx in sampled_indices] diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py deleted file mode 100644 index 7811e7365..000000000 --- a/applications/Chat/coati/trainer/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any - -import torch -import torch.distributed as dist -from torch.utils._pytree import tree_map -from torch.utils.data import 
DataLoader - - -class CycledDataLoader: - """ - Why do we need this class? - In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain. - However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...) - NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior. - """ - - def __init__( - self, - dataloader: DataLoader, - ) -> None: - self.dataloader = dataloader - - self.count = 0 - self.dataloader_iter = None - - def next(self): - # defer initialization - if self.dataloader_iter is None: - self.dataloader_iter = iter(self.dataloader) - - self.count += 1 - try: - return next(self.dataloader_iter) - except StopIteration: - self.count = 0 - self.dataloader_iter = iter(self.dataloader) - return next(self.dataloader_iter) - - -def is_rank_0() -> bool: - return not dist.is_initialized() or dist.get_rank() == 0 - - -def to_device(x: Any, device: torch.device) -> Any: - def _to(t: Any): - if isinstance(t, torch.Tensor): - return t.to(device) - return t - - return tree_map(_to, x) diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md deleted file mode 100644 index 9438aafd1..000000000 --- a/applications/Chat/examples/README.md +++ /dev/null @@ -1,409 +0,0 @@ -# Examples - -## Table of Contents - -- [Examples](#examples) - - [Table of Contents](#table-of-contents) - - [Install requirements](#install-requirements) - - [Supervised datasets collection](#supervised-datasets-collection) - - [Conversation dataset generation](#conversation-dataset-generation) - - [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning) - - [Arg List](#arg-list) - - [Stage2 - Training reward model](#stage2---training-reward-model) - - [Features and tricks in RM training](#features-and-tricks-in-rm-training) - - [Experiment result](#experiment-result) - - [Arg List](#arg-list-1) - - [Stage3 - Training model using prompts with RL](#stage3---training-model-using-prompts-with-rl) - - [Arg List](#arg-list-2) - - [Inference example - After Stage3](#inference-example---after-stage3) - - [Attention](#attention) - - [data](#data) - - [Support Model](#support-model) - - [GPT](#gpt) - - [BLOOM](#bloom) - - [OPT](#opt) - - [LLaMA](#llama) - - [Add your own models](#add-your-own-models) - - [Actor model](#actor-model) - - [Reward model](#reward-model) - - [Critic model](#critic-model) - ---- - -## Install requirements - -```shell -pip install -r requirements.txt -``` - -## Supervised datasets collection - -We collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo -[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md). - -Here is how we collected the data - -

-[figure: overview of the data collection process]

                  - -### Conversation dataset generation - -In order to further improve the model's ability to handle multi-turn conversations, we need to include samples with multi-turn conversations in the dataset. However, the samples in InstructWild and Alpaca datasets currently consist of only single-turn conversations, and their dataset organization is not suitable for storing multi-turn conversations. Additionally, after converting the aforementioned datasets, we also need to include multi-turn conversation datasets like ShareGPT, and we should transform them into the training format supported by ColossalChat. - -A sample of conversation dataset should have the following fields: - -- `type` (str, optional): The type of the data sample. -- `language` (str, optional): The language of the data sample. -- `dataset` (str, optional): The dataset the data sample originates from. -- `conversations` (str, compulsory): Conversation content of the data sample. -- `id` (int, optional): The ID of the data sample. - -A simple example: - -```json -{ - "type": "instruction", - "language": "English", - "dataset": "Alpaca", - "conversations": [ - { - "from": "human", - "value": "Give three tips for staying healthy." - }, - { - "from": "gpt", - "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule." - } - ], - "id": 1 -} -``` - -> **NOTE:** Only key `conversations` is compulsary for training and other keys serve as metadata. The length of `conversations` varies. - -You can run the `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat. - -You can use the following cmd to generate conversation dataset. - -```bash -python generate_conversation_dataset.py \ - --dataset "All" - --save_path "/path/to/dataset" -``` - -## Stage1 - Supervised instructs tuning - -Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. -[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg) - -You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. - -You can also use the following cmd to start a supervised instructs fine-tuning with your own settings. - -```bash -torchrun --standalone --nproc_per_node=4 train_sft.py \ - --pretrain "/path/to/LLaMa-7B/" \ - --model 'llama' \ - --strategy colossalai_zero2 \ - --save_path /path/to/Coati-7B \ - --dataset /path/to/data.json \ - --batch_size 4 \ - --accumulation_steps 8 \ - --lr 2e-5 \ - --max_datasets_size 512 \ - --max_epochs 1 \ - --grad_checkpoint -``` - -**Note**: the supervised dataset follows the following format, - -```json -[ - { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 - }, - ... 
-] -``` - -### Arg List - -- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- `--pretrain`: pretrain model, type=str, default=None -- `--max_datasets_size`: the max size of dataset, type=int, default=None -- `--save_path`: path to save the model, type=str, default='output' -- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False -- `--max_epochs`: max epochs for training, type=int, default=3 -- `--batch_size`: batch size while training, type=int, default=4 -- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 -- `--grad_checkpoint`: enable gradient checkpointing, type=bool, default=False - -## Stage2 - Training reward model - -We train a reward model in stage 2, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model. -[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo) - -You can run the `examples/train_rm.sh` to start a reward model training. - -You can also use the following cmd to start training a reward model. - -```bash -torchrun --standalone --nproc_per_node=4 train_reward_model.py \ - --pretrain "/path/to/LLaMa-7B/" \ - --model 'llama' \ - --strategy colossalai_zero2 \ - --loss_fn 'log_exp'\ - --save_path 'rmstatic.pt' \ -``` - -### Features and tricks in RM training - -- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets. -- We support 2 kinds of loss function named `log_sig`(used by OpenAI) and `log_exp`(used by Anthropic). -- We change the loss to `valid_acc` and `pair_dist` to monitor progress during training. -- We add special token to the end of the sequence to get better result. -- We use cosine-reducing lr-scheduler for RM training. -- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution. -- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2204.05862). - -### Experiment result - -Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862): - -
-[figure: reward model performance reported in Anthropic's paper]
-Our training & test result of bloom-560m for 1 epoch:
-[figure: training & test result of our bloom-560m reward model]
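As a side note on the `log_sig` and `log_exp` options listed under the RM training features above: both reduce to a pairwise ranking objective on the two reward scores. A minimal sketch, with function names and signatures ours rather than the exact classes in `coati.models.loss`:

```python
import torch
import torch.nn.functional as F


def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # OpenAI-style pairwise loss: -log(sigmoid(r_chosen - r_rejected))
    return -F.logsigmoid(chosen_reward - reject_reward).mean()


def log_exp_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # Anthropic-style pairwise loss: log(1 + exp(r_rejected - r_chosen))
    return torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
```

Algebraically the two are identical; the `logsigmoid` form is simply the numerically safer way to compute it.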
                  We also train the reward model based on LLaMA-7B, which reaches the ACC of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM. - -### Arg List - -- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- `--pretrain`: pretrain model, type=str, default=None -- `--model_path`: the path of rm model(if continue to train), type=str, default=None -- `--save_path`: path to save the model, type=str, default='output' -- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False -- `--max_epochs`: max epochs for training, type=int, default=3 -- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'] -- `--subset`: subset of the dataset, type=str, default=None -- `--batch_size`: batch size while training, type=int, default=4 -- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 -- `--loss_func`: which kind of loss function, choices=['log_sig', 'log_exp'] -- `--max_len`: max sentence length for generation, type=int, default=512 - -## Stage3 - Training model using prompts with RL - -Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process, as shown below: - -

-[figure: stage 3 RL (PPO) training process]
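The core of the policy update in this stage is the standard PPO clipped surrogate objective that `PolicyLoss(eps_clip)` implements. A minimal sketch of that objective only; the actual `coati.models.loss.PolicyLoss` additionally applies the action mask that `_training_step` passes in:

```python
import torch


def clipped_policy_loss(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip: float = 0.2,
) -> torch.Tensor:
    # Probability ratio between the current policy and the rollout (old) policy.
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    # PPO maximizes the clipped surrogate, so the loss is its negation.
    return -torch.min(surr1, surr2).mean()
```

In `_training_step` the resulting actor loss is further scaled by `(1 - ptx_coef)` before `strategy.backward`, so the pretraining (ptx) loss and the policy loss contribute to a single actor optimizer step.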

                  - -You can run the `examples/train_prompts.sh` to start PPO training. - -You can also use the cmd following to start PPO training. -[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g) - -```bash -torchrun --standalone --nproc_per_node=4 train_prompts.py \ - --pretrain "/path/to/LLaMa-7B/" \ - --model 'llama' \ - --strategy colossalai_zero2 \ - --prompt_dataset /path/to/your/prompt_dataset \ - --pretrain_dataset /path/to/your/pretrain_dataset \ - --rm_pretrain /your/pretrain/rm/definition \ - --rm_path /your/rm/model/path -``` - -Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset. -Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning. - -**Note**: the required datasets follow the following format, - -- `pretrain dataset` - - ```json - [ - { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 - }, - ... - ] - ``` - -- `prompt dataset` - - ```json - [ - { - "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": 0 - }, - { - "instruction": "Write a descriptive paragraph about a memorable vacation you went on", - "id": 1 - }, - ... 
- ] - ``` - -### Arg List - -- `--strategy`: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2' -- `--model`: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom' -- `--pretrain`: pretrain model, type=str, default=None -- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None -- `--rm_pretrain`: pretrain model for reward model, type=str, default=None -- `--rm_path`: the path of rm model, type=str, default=None -- `--save_path`: path to save the model, type=str, default='output' -- `--prompt_dataset`: path of the prompt dataset, type=str, default=None -- `--pretrain_dataset`: path of the ptx dataset, type=str, default=None -- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False -- `--num_episodes`: num of episodes for training, type=int, default=10 -- `--num_update_steps`: number of steps to update policy per episode, type=int -- `--num_collect_steps`: number of steps to collect experience per episode, type=int -- `--train_batch_size`: batch size while training, type=int, default=8 -- `--ptx_batch_size`: batch size to compute ptx loss, type=int, default=1 -- `--experience_batch_size`: batch size to make experience, type=int, default=8 -- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0 -- `--kl_coef`: kl_coef using for computing reward, type=float, default=0.1 -- `--ptx_coef`: ptx_coef using for computing policy loss, type=float, default=0.9 - -## Inference example - After Stage3 - -We support different inference options, including int8 and int4 quantization. -For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). - -## Attention - -The examples are demos for the whole training process.You need to change the hyper-parameters to reach great performance. 
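To make the inference step above concrete: the `examples/inference.py` script removed later in this patch accepts `--model`, `--pretrain`, `--model_path`, `--input`, and `--max_length`. A typical invocation might look like the following (all paths are placeholders):

```bash
python inference.py \
    --model 'llama' \
    --pretrain "/path/to/LLaMa-7B/" \
    --model_path "/path/to/Coati-7B/actor.pt" \
    --input "Question: How are you ? Answer:" \
    --max_length 100
```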
- -#### data - -- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) -- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) -- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback) -- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons) -- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses) - -## Support Model - -### GPT - -- [x] GPT2-S (s) -- [x] GPT2-M (m) -- [x] GPT2-L (l) -- [x] GPT2-XL (xl) -- [x] GPT2-4B (4b) -- [ ] GPT2-6B (6b) - -### BLOOM - -- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m) -- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1) -- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b) -- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1) -- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom) - -### OPT - -- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m) -- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m) -- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) -- [x] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b) -- [x] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b) -- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b) -- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b) - -### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) - -- [x] LLaMA-7B -- [x] LLaMA-13B -- [ ] LLaMA-33B -- [ ] LLaMA-65B - -## Add your own models - -If you want to support your own model in Coati, please refer the pull request for RoBERTa support as an example --[[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3](https://github.com/hpcaitech/ColossalAI/pull/3223), and submit a PR to us. - -You should complete the implementation of four model classes, including Reward model, Critic model, LM model, Actor model - -here are some example code for a NewModel named `Coati`. -if it is supported in huggingface [transformers](https://github.com/huggingface/transformers), you can load it by `from_pretrained`, o -r you can build your own model by yourself. 
- -### Actor model - -```python -from ..base import Actor -from transformers.models.coati import CoatiModel - -class CoatiActor(Actor): - def __init__(self, - pretrained: Optional[str] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = CoatiModel.from_pretrained(pretrained) - else: - model = build_model() # load your own model if it is not support in transformers - - super().__init__(model, lora_rank, lora_train_bias) -``` - -### Reward model - -```python -from ..base import RewardModel -from transformers.models.coati import CoatiModel - -class CoatiRM(RewardModel): - - def __init__(self, - pretrained: Optional[str] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = CoatiModel.from_pretrained(pretrained) - else: - model = build_model() # load your own model if it is not support in transformers - - value_head = nn.Linear(model.config.n_embd, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) -``` - -### Critic model - -```python -from ..base import Critic -from transformers.models.coati import CoatiModel - -class CoatiCritic(Critic): - def __init__(self, - pretrained: Optional[str] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - if pretrained is not None: - model = CoatiModel.from_pretrained(pretrained) - else: - model = build_model() # load your own model if it is not support in transformers - - value_head = nn.Linear(model.config.n_embd, 1) - value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) - super().__init__(model, value_head, lora_rank, lora_train_bias) -``` diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py deleted file mode 100644 index ec3482b5f..000000000 --- a/applications/Chat/examples/download_model.py +++ /dev/null @@ -1,79 +0,0 @@ -import argparse -import dataclasses -import os -import parser -from typing import List - -import tqdm -from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.opt import OPTRM, OPTActor, OPTCritic -from huggingface_hub import hf_hub_download, snapshot_download -from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer - - -@dataclasses.dataclass -class HFRepoFiles: - repo_id: str - files: List[str] - - def download(self, dir_path: str): - for file in self.files: - file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) - - def download_all(self): - snapshot_download(self.repo_id) - - -def test_init(model: str, dir_path: str): - if model == "gpt2": - config = GPT2Config.from_pretrained(dir_path) - actor = GPTActor(config=config) - critic = GPTCritic(config=config) - reward_model = GPTRM(config=config) - GPT2Tokenizer.from_pretrained(dir_path) - elif model == "bloom": - config = BloomConfig.from_pretrained(dir_path) - actor = BLOOMActor(config=config) - critic = BLOOMCritic(config=config) - reward_model = BLOOMRM(config=config) - BloomTokenizerFast.from_pretrained(dir_path) - elif model == "opt": - config = AutoConfig.from_pretrained(dir_path) - actor = OPTActor(config=config) - critic = OPTCritic(config=config) - reward_model = OPTRM(config=config) - AutoTokenizer.from_pretrained(dir_path) - else: - raise 
NotImplementedError(f"Model {model} not implemented") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model-dir", type=str, default="test_models") - parser.add_argument("--config-only", default=False, action="store_true") - args = parser.parse_args() - - if os.path.exists(args.model_dir): - print(f"[INFO]: {args.model_dir} already exists") - exit(0) - - repo_list = { - "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]), - "bloom": HFRepoFiles( - repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"] - ), - "opt": HFRepoFiles( - repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] - ), - } - - os.mkdir(args.model_dir) - for model_name in tqdm.tqdm(repo_list): - dir_path = os.path.join(args.model_dir, model_name) - if args.config_only: - os.mkdir(dir_path) - repo_list[model_name].download(dir_path) - else: - repo_list[model_name].download_all() - test_init(model_name, dir_path) diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py deleted file mode 100644 index 7e03b2d54..000000000 --- a/applications/Chat/examples/generate_conversation_dataset.py +++ /dev/null @@ -1,82 +0,0 @@ -import argparse -import json - -from datasets import load_dataset - - -def generate_alpaca(): - # We can convert dataset with the same format("instruction", "input", "output") as Alpaca into a one-round conversation. - conversation_dataset = [] - dataset = load_dataset("tatsu-lab/alpaca", split="train") - - instructions = dataset["instruction"] - inputs = dataset["input"] - outputs = dataset["output"] - - assert len(instructions) == len(inputs) == len(outputs) - - for idx in range(len(instructions)): - human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx] - human = {"from": "human", "value": human_utterance} - - gpt_utterance = outputs[idx] - gpt = {"from": "gpt", "value": gpt_utterance} - - conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt]) - conversation_dataset.append(conversation) - - return conversation_dataset - - -def generate_sharegpt(): - # ShareGPT data requires less processing. - conversation_dataset = [] - dataset = load_dataset( - "anon8231489123/ShareGPT_Vicuna_unfiltered", - data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", - split="train", - ) - - conversations = dataset["conversations"] - - for idx in range(len(conversations)): - for conv in conversations[idx]: - # We don't need markdown and text value. 
- del conv["markdown"] - del conv["text"] - - conversation = dict( - type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx] - ) - conversation_dataset.append(conversation) - - return conversation_dataset - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dataset", - type=str, - default="All", - choices=["Alpaca", "ShareGPT", "All"], - help="which dataset to convert, All will combine Alpaca and ShareGPT", - ) - parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset") - args = parser.parse_args() - - conversation_dataset = [] - - if args.dataset == "Alpaca": - conversation_dataset.extend(generate_alpaca()) - elif args.dataset == "ShareGPT": - conversation_dataset.extend(generate_sharegpt()) - else: - conversation_dataset.extend(generate_alpaca()) - conversation_dataset.extend(generate_sharegpt()) - - for idx, sample in enumerate(conversation_dataset): - sample["id"] = idx + 1 - - with open(args.save_path, mode="w") as f: - json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False) diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py deleted file mode 100644 index 4eec6feae..000000000 --- a/applications/Chat/examples/generate_prompt_dataset.py +++ /dev/null @@ -1,27 +0,0 @@ -import argparse -import json -import random - -random.seed(42) - - -def sample(args): - with open(args.dataset_path, mode="r") as f: - dataset_list = json.load(f) - - sampled_dataset = [ - {"instruction": sample["instruction"], "id": idx} - for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) - ] - - with open(args.save_path, mode="w") as f: - json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset") - parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset") - parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset") - args = parser.parse_args() - sample(args) diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py deleted file mode 100644 index 9df8649d9..000000000 --- a/applications/Chat/examples/inference.py +++ /dev/null @@ -1,73 +0,0 @@ -import argparse - -import torch -from coati.models.bloom import BLOOMActor -from coati.models.generation import generate -from coati.models.gpt import GPTActor -from coati.models.llama import LlamaActor -from coati.models.opt import OPTActor -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer - - -def eval(args): - # configure model - if args.model == "gpt2": - actor = GPTActor(pretrained=args.pretrain) - elif args.model == "bloom": - actor = BLOOMActor(pretrained=args.pretrain) - elif args.model == "opt": - actor = OPTActor(pretrained=args.pretrain) - elif args.model == "llama": - actor = LlamaActor(pretrained=args.pretrain) - else: - raise ValueError(f'Unsupported model "{args.model}"') - - actor.to(torch.cuda.current_device()) - if args.model_path is not None: - state_dict = torch.load(args.model_path) - actor.load_state_dict(state_dict) - - # configure tokenizer - if args.model == "gpt2": - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - tokenizer.pad_token = 
tokenizer.eos_token - elif args.model == "bloom": - tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "opt": - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "llama": - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.eos_token = "" - tokenizer.pad_token = tokenizer.unk_token - else: - raise ValueError(f'Unsupported model "{args.model}"') - - actor.eval() - tokenizer.padding_side = "left" - input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) - outputs = generate( - actor, - input_ids, - tokenizer=tokenizer, - max_length=args.max_length, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1, - ) - output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) - print(f"[Output]: {''.join(output)}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) - # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model - parser.add_argument("--pretrain", type=str, default=None) - parser.add_argument("--model_path", type=str, default=None) - parser.add_argument("--input", type=str, default="Question: How are you ? Answer:") - parser.add_argument("--max_length", type=int, default=100) - args = parser.parse_args() - eval(args) diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py deleted file mode 100644 index 40e06043a..000000000 --- a/applications/Chat/examples/train_prompts.py +++ /dev/null @@ -1,249 +0,0 @@ -import argparse -import warnings - -import torch -import torch.distributed as dist -from coati.dataset import PromptDataset, SupervisedDataset -from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.opt import OPTRM, OPTActor, OPTCritic -from coati.trainer import PPOTrainer -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from torch.optim import Adam -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer - -from colossalai.nn.optimizer import HybridAdam - - -def main(args): - # configure strategy - if args.strategy == "ddp": - strategy = DDPStrategy() - elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5) - elif args.strategy == "colossalai_zero2": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") - else: - raise ValueError(f'Unsupported strategy "{args.strategy}"') - - if args.rm_path is not None: - warnings.warn("LoRA weights should be merged with the model weights") - state_dict = torch.load(args.rm_path, map_location="cpu") - - if args.lora_rank > 0: - warnings.warn("Lora is not supported yet.") - args.lora_rank = 0 - - with strategy.model_init_context(): - # configure model - if args.model == "gpt2": - initial_model = GPTActor(pretrained=args.pretrain) - elif args.model == "bloom": - initial_model = BLOOMActor(pretrained=args.pretrain) - elif args.model == "opt": - initial_model = OPTActor(pretrained=args.pretrain) - elif 
args.model == "llama": - initial_model = LlamaActor(pretrained=args.pretrain) - else: - raise ValueError(f'Unsupported actor model "{args.model}"') - - if args.rm_model is None: - rm_model_name = args.model - else: - rm_model_name = args.rm_model - - if rm_model_name == "gpt2": - reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "bloom": - reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "opt": - reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "llama": - reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - else: - raise ValueError(f'Unsupported reward model "{rm_model_name}"') - - if args.rm_path is not None: - reward_model.load_state_dict(state_dict, strict=False) - - initial_model.to(torch.bfloat16).to(torch.cuda.current_device()) - reward_model.to(torch.bfloat16).to(torch.cuda.current_device()) - - if args.model == "gpt2": - actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "bloom": - actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "opt": - actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "llama": - actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - else: - raise ValueError(f'Unsupported actor model "{args.model}"') - - if rm_model_name == "gpt2": - critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "bloom": - critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "opt": - critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == "llama": - critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - else: - raise ValueError(f'Unsupported reward model "{rm_model_name}"') - - if args.rm_path is not None: - critic.load_state_dict(state_dict, strict=False) - del state_dict - - actor.to(torch.bfloat16).to(torch.cuda.current_device()) - critic.to(torch.bfloat16).to(torch.cuda.current_device()) - - # configure optimizer - if args.strategy.startswith("colossalai"): - actor_optim = HybridAdam(actor.parameters(), lr=args.lr) - critic_optim = HybridAdam(critic.parameters(), lr=args.lr) - else: - actor_optim = Adam(actor.parameters(), lr=args.lr) - critic_optim = Adam(critic.parameters(), lr=args.lr) - - # configure tokenizer - if args.model == "gpt2": - tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "bloom": - tokenizer = BloomTokenizerFast.from_pretrained( - "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer - ) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "opt": - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "llama": - tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer - ) - tokenizer.eos_token = "" - tokenizer.pad_token = tokenizer.unk_token - else: - raise ValueError(f'Unsupported model "{args.model}"') - # NOTE: generate() requires padding_side to be "left" - tokenizer.padding_side = "left" - - prompt_dataset = PromptDataset( - tokenizer=tokenizer, - 
data_path=args.prompt_dataset, - max_datasets_size=args.max_datasets_size, - max_length=args.max_input_len, - ) - if dist.is_initialized() and dist.get_world_size() > 1: - prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) - else: - prompt_sampler = None - prompt_dataloader = DataLoader( - prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size - ) - - pretrain_dataset = SupervisedDataset( - tokenizer=tokenizer, - data_path=args.pretrain_dataset, - max_datasets_size=args.max_datasets_size, - max_length=args.max_input_len, - ) - if dist.is_initialized() and dist.get_world_size() > 1: - pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) - else: - pretrain_sampler = None - pretrain_dataloader = DataLoader( - pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size - ) - - # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. - (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( - (actor, actor_optim), (critic, critic_optim), reward_model, initial_model - ) - - # configure trainer - trainer = PPOTrainer( - strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - tokenizer=tokenizer, - kl_coef=args.kl_coef, - ptx_coef=args.ptx_coef, - train_batch_size=args.train_batch_size, - max_length=args.max_seq_len, - use_cache=True, - do_sample=True, - temperature=1.0, - top_k=50, - offload_inference_models=args.strategy != "colossalai_gemini", - ) - - trainer.fit( - num_episodes=args.num_episodes, - num_collect_steps=args.num_collect_steps, - num_update_steps=args.num_update_steps, - prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - log_dir=args.log_dir, - use_wandb=args.use_wandb, - ) - - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - - # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True - actor.eval() - # save model checkpoint after fitting - strategy.save_pretrained(actor, path=args.save_path) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - strategy.save_optimizer( - actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") - parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") - parser.add_argument("--max_datasets_size", type=int, default=50000) - parser.add_argument( - "--strategy", - choices=["ddp", "colossalai_gemini", "colossalai_zero2"], - default="colossalai_zero2", - help="strategy to use", - ) - parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument("--pretrain", type=str, default=None) - parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) - parser.add_argument("--rm_path", type=str, default=None) - parser.add_argument("--rm_pretrain", type=str, default=None) - parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") - parser.add_argument("--need_optim_ckpt", type=bool, default=False) - 
parser.add_argument("--num_episodes", type=int, default=10) - parser.add_argument("--num_collect_steps", type=int, default=10) - parser.add_argument("--num_update_steps", type=int, default=5) - parser.add_argument("--train_batch_size", type=int, default=8) - parser.add_argument("--ptx_batch_size", type=int, default=1) - parser.add_argument("--experience_batch_size", type=int, default=8) - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument("--merge_lora_weights", type=bool, default=True) - parser.add_argument("--lr", type=float, default=1e-7) - parser.add_argument("--kl_coef", type=float, default=0.1) - parser.add_argument("--ptx_coef", type=float, default=0.9) - parser.add_argument("--max_input_len", type=int, default=96) - parser.add_argument("--max_seq_len", type=int, default=128) - parser.add_argument("--log_dir", default="logs", type=str) - parser.add_argument("--use_wandb", default=False, action="store_true") - args = parser.parse_args() - main(args) diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh deleted file mode 100755 index d04c41601..000000000 --- a/applications/Chat/examples/train_prompts.sh +++ /dev/null @@ -1,25 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2 - -torchrun --standalone --nproc_per_node=2 train_prompts.py \ - --pretrain_dataset /path/to/data.json \ - --prompt_dataset /path/to/data.json \ - --strategy colossalai_zero2 \ - --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \ - --train_batch_size 2 diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py deleted file mode 100644 index fcdd29b29..000000000 --- a/applications/Chat/examples/train_reward_model.py +++ /dev/null @@ -1,208 +0,0 @@ -import argparse -import warnings - -import torch -import torch.distributed as dist -from coati.dataset import HhRlhfDataset, RmStaticDataset -from coati.models import LogExpLoss, LogSigLoss -from coati.models.bloom import BLOOMRM -from coati.models.gpt import GPTRM -from coati.models.llama import LlamaRM -from coati.models.opt import OPTRM -from coati.trainer import RewardModelTrainer -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from datasets import load_dataset -from torch.optim import Adam -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -from colossalai.nn.optimizer import HybridAdam - - -def train(args): - # configure strategy - if args.strategy == "ddp": - strategy = DDPStrategy() - elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="auto") - elif args.strategy == "colossalai_zero2": - strategy = LowLevelZeroStrategy(stage=2, 
placement_policy="cuda") - else: - raise ValueError(f'Unsupported strategy "{args.strategy}"') - - # configure model - if args.lora_rank > 0: - warnings.warn("Lora is not supported yet.") - args.lora_rank = 0 - - with strategy.model_init_context(): - if args.model == "bloom": - model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "opt": - model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "gpt2": - model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == "llama": - model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - else: - raise ValueError(f'Unsupported model "{args.model}"') - - model.to(torch.bfloat16).to(torch.cuda.current_device()) - - if args.model_path is not None: - state_dict = torch.load(args.model_path) - model.load_state_dict(state_dict) - - # configure tokenizer - if args.model == "gpt2": - tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "bloom": - tokenizer = BloomTokenizerFast.from_pretrained( - "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer - ) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "opt": - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "llama": - tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer - ) - tokenizer.eos_token = "" - tokenizer.pad_token = tokenizer.unk_token - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure optimizer - if args.strategy.startswith("colossalai"): - optim = HybridAdam(model.parameters(), lr=args.lr) - else: - optim = Adam(model.parameters(), lr=args.lr) - - # configure loss function - if args.loss_fn == "log_sig": - loss_fn = LogSigLoss() - elif args.loss_fn == "log_exp": - loss_fn = LogExpLoss() - else: - raise ValueError(f'Unsupported loss function "{args.loss_fn}"') - - # prepare for data and dataset - if args.subset is not None: - data = load_dataset(args.dataset, data_dir=args.subset) - else: - data = load_dataset(args.dataset) - - train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"])))) - eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"])))) - - if args.dataset == "Dahoas/rm-static": - train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) - eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) - elif args.dataset == "Anthropic/hh-rlhf": - train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) - eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) - else: - raise ValueError(f'Unsupported dataset "{args.dataset}"') - - if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler( - train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - ) - eval_sampler = DistributedSampler( - eval_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - ) - else: - train_sampler = None - eval_sampler = None - - train_dataloader = DataLoader( - train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True, - ) - - 
eval_dataloader = DataLoader( - eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True - ) - - lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100) - strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict["model"] - optim = strategy_dict["optimizer"] - lr_scheduler = strategy_dict["lr_scheduler"] - trainer = RewardModelTrainer( - model=model, - strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - loss_fn=loss_fn, - max_epochs=args.max_epochs, - ) - - trainer.fit( - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - log_dir=args.log_dir, - use_wandb=args.use_wandb, - ) - - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - - # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True - model.eval() - # save model checkpoint after fitting on only rank0 - state_dict = model.state_dict() - torch.save(state_dict, args.save_path) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - strategy.save_optimizer( - trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2" - ) - parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument("--pretrain", type=str, default=None) - parser.add_argument("--model_path", type=str, default=None) - parser.add_argument("--need_optim_ckpt", type=bool, default=False) - parser.add_argument( - "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" - ) - parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) - parser.add_argument("--max_datasets_size", type=int, default=1000000) - parser.add_argument("--save_path", type=str, default="rm_ckpt") - parser.add_argument("--max_epochs", type=int, default=1) - parser.add_argument("--batch_size", type=int, default=1) - parser.add_argument("--max_len", type=int, default=512) - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument("--merge_lora_weights", type=bool, default=True) - parser.add_argument("--lr", type=float, default=9e-6) - parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) - parser.add_argument("--log_dir", default="logs", type=str) - parser.add_argument("--use_wandb", default=False, action="store_true") - args = parser.parse_args() - train(args) diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh deleted file mode 100755 index c5ebaf708..000000000 --- a/applications/Chat/examples/train_rm.sh +++ /dev/null @@ -1,25 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - 
-set_n_least_used_CUDA_VISIBLE_DEVICES 2 - -torchrun --standalone --nproc_per_node=2 train_reward_model.py \ - --pretrain 'gpt2' \ - --model 'gpt2' \ - --strategy colossalai_zero2 \ - --loss_fn 'log_exp' \ - --dataset 'Anthropic/hh-rlhf' \ - --batch_size 16 \ - --max_epochs 10 diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py deleted file mode 100644 index d00c04809..000000000 --- a/applications/Chat/examples/train_sft.py +++ /dev/null @@ -1,221 +0,0 @@ -import argparse -import math -import warnings - -import torch -import torch.distributed as dist -from coati.dataset import SFTDataset, SupervisedDataset -from coati.models.bloom import BLOOMActor -from coati.models.chatglm import ChatGLMActor -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer -from coati.models.gpt import GPTActor -from coati.models.llama import LlamaActor -from coati.models.opt import OPTActor -from coati.trainer import SFTTrainer -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from datasets import load_dataset -from torch.optim import Adam -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from transformers.trainer import get_scheduler - -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import HybridAdam - - -def train(args): - # configure strategy - if args.strategy == "ddp": - strategy = DDPStrategy() - elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="auto") - elif args.strategy == "colossalai_zero2": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") - elif args.strategy == "colossalai_zero2_cpu": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") - else: - raise ValueError(f'Unsupported strategy "{args.strategy}"') - - # configure model - if args.lora_rank > 0: - warnings.warn("Lora is not supported yet.") - args.lora_rank = 0 - - with strategy.model_init_context(): - if args.model == "bloom": - model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) - elif args.model == "opt": - model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) - elif args.model == "gpt2": - model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) - elif args.model == "llama": - model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) - elif args.model == "chatglm": - model = ChatGLMActor(pretrained=args.pretrain) - else: - raise ValueError(f'Unsupported model "{args.model}"') - - model.to(torch.bfloat16).to(torch.cuda.current_device()) - - # configure tokenizer - if args.model == "gpt2": - tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "bloom": - tokenizer = BloomTokenizerFast.from_pretrained( - "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer - ) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "opt": - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - elif args.model == "llama": - tokenizer = 
LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer - ) - tokenizer.eos_token = "" - tokenizer.pad_token = tokenizer.unk_token - elif args.model == "chatglm": - tokenizer = ChatGLMTokenizer.from_pretrained( - "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True - ) - else: - raise ValueError(f'Unsupported model "{args.model}"') - - # configure optimizer - if args.strategy.startswith("colossalai"): - optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) - else: - optim = Adam(model.parameters(), lr=args.lr) - - # configure dataset - if args.dataset == "yizhongw/self_instruct": - train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") - eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") - - if args.max_datasets_size is not None: - train_data = train_data.select(range(min(args.max_datasets_size, len(train_data)))) - eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data)))) - - train_dataset = SFTDataset(train_data, tokenizer, args.max_len) - eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) - - else: - train_dataset = SupervisedDataset( - tokenizer=tokenizer, - data_path=args.dataset, - max_datasets_size=args.max_datasets_size, - max_length=args.max_len, - ) - eval_dataset = None - - if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler( - train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - ) - if eval_dataset is not None: - eval_sampler = DistributedSampler( - eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - ) - else: - train_sampler = None - eval_sampler = None - - train_dataloader = DataLoader( - train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True, - ) - if eval_dataset is not None: - eval_dataloader = DataLoader( - eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True, - ) - else: - eval_dataloader = None - - num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps - max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) - lr_scheduler = get_scheduler( - "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps - ) - strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict["model"] - optim = strategy_dict["optimizer"] - lr_scheduler = strategy_dict["lr_scheduler"] - trainer = SFTTrainer( - model=model, - strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps, - ) - - logger = get_dist_logger() - trainer.fit( - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - logger=logger, - log_dir=args.log_dir, - use_wandb=args.use_wandb, - ) - - if args.lora_rank > 0 and args.merge_lora_weights: - from coati.models.lora import LORA_MANAGER - - # NOTE: set model to eval to merge LoRA weights - LORA_MANAGER.merge_weights = True - model.eval() - # save model checkpoint after fitting on only rank0 - strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer) - # save optimizer checkpoint on all ranks - if args.need_optim_ckpt: - 
strategy.save_optimizer( - trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--strategy", - choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"], - default="colossalai_zero2", - ) - parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument("--pretrain", type=str, default=None) - parser.add_argument("--dataset", type=str, default=None) - parser.add_argument("--max_datasets_size", type=int, default=None) - parser.add_argument("--save_path", type=str, default="output") - parser.add_argument("--need_optim_ckpt", type=bool, default=False) - parser.add_argument("--max_epochs", type=int, default=3) - parser.add_argument("--batch_size", type=int, default=4) - parser.add_argument("--max_len", type=int, default=512) - parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument("--merge_lora_weights", type=bool, default=True) - parser.add_argument("--lr", type=float, default=5e-6) - parser.add_argument("--accumulation_steps", type=int, default=8) - parser.add_argument("--log_dir", default="logs", type=str) - parser.add_argument("--use_wandb", default=False, action="store_true") - parser.add_argument("--grad_checkpoint", default=False, action="store_true") - args = parser.parse_args() - train(args) diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh deleted file mode 100755 index b7d176847..000000000 --- a/applications/Chat/examples/train_sft.sh +++ /dev/null @@ -1,28 +0,0 @@ -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 4 - -torchrun --standalone --nproc_per_node=4 train_sft.py \ - --pretrain "/path/to/LLaMa-7B/" \ - --model 'llama' \ - --strategy colossalai_zero2 \ - --save_path /path/to/Coati-7B \ - --dataset /path/to/data.json \ - --batch_size 4 \ - --accumulation_steps 8 \ - --lr 2e-5 \ - --max_datasets_size 512 \ - --max_epochs 1 \ No newline at end of file diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py deleted file mode 100644 index dbb5490a6..000000000 --- a/applications/Chat/inference/benchmark.py +++ /dev/null @@ -1,141 +0,0 @@ -# Adapted from https://github.com/tloen/alpaca-lora/blob/main/generate.py - -import argparse -from time import time - -import torch -from coati.quant import llama_load_quant, low_resource_init -from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM - - -def generate_prompt(instruction, input=None): - if input: - return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - -### Instruction: -{instruction} - -### Input: -{input} - -### Response:""" - else: - return f"""Below is an instruction that describes a task. 
Write a response that appropriately completes the request. - -### Instruction: -{instruction} - -### Response:""" - - -@torch.no_grad() -def evaluate( - model, - tokenizer, - instruction, - input=None, - temperature=0.1, - top_p=0.75, - top_k=40, - num_beams=4, - max_new_tokens=128, - **kwargs, -): - prompt = generate_prompt(instruction, input) - inputs = tokenizer(prompt, return_tensors="pt") - input_ids = inputs["input_ids"].cuda() - generation_config = GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - **kwargs, - ) - generation_output = model.generate( - input_ids=input_ids, - generation_config=generation_config, - return_dict_in_generate=True, - output_scores=True, - max_new_tokens=max_new_tokens, - do_sample=True, - ) - s = generation_output.sequences[0] - output = tokenizer.decode(s) - n_new_tokens = s.size(0) - input_ids.size(1) - return output.split("### Response:")[1].strip(), n_new_tokens - - -instructions = [ - "Tell me about alpacas.", - "Tell me about the president of Mexico in 2019.", - "Tell me about the king of France in 2019.", - "List all Canadian provinces in alphabetical order.", - "Write a Python program that prints the first 10 Fibonacci numbers.", - "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", - "Tell me five words that rhyme with 'shock'.", - "Translate the sentence 'I have no mouth but I must scream' into Spanish.", - "Count up from 1 to 500.", - # === - "How to play support in legends of league", - "Write a Python program that calculate Fibonacci numbers.", -] -inst = [instructions[0]] * 4 - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "pretrained", - help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", - ) - parser.add_argument( - "--quant", - choices=["8bit", "4bit"], - default=None, - help="Quantization mode. Default: None (no quantization, fp16).", - ) - parser.add_argument( - "--gptq_checkpoint", - default=None, - help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", - ) - parser.add_argument( - "--gptq_group_size", - type=int, - default=128, - help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", - ) - args = parser.parse_args() - - if args.quant == "4bit": - assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." - - tokenizer = AutoTokenizer.from_pretrained(args.pretrained) - - if args.quant == "4bit": - with low_resource_init(): - config = LlamaConfig.from_pretrained(args.pretrained) - model = LlamaForCausalLM(config) - model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) - model.cuda() - else: - model = LlamaForCausalLM.from_pretrained( - args.pretrained, - load_in_8bit=(args.quant == "8bit"), - torch_dtype=torch.float16, - device_map="auto", - ) - if args.quant != "8bit": - model.half() # seems to fix bugs for some users. 
- model.eval() - - total_tokens = 0 - start = time() - for instruction in instructions: - print(f"Instruction: {instruction}") - resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1) - total_tokens += tokens - print(f"Response: {resp}") - print("\n----------------------------\n") - duration = time() - start - print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s") - print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py deleted file mode 100644 index 9835e7189..000000000 --- a/applications/Chat/inference/tests/test_chat_prompt.py +++ /dev/null @@ -1,61 +0,0 @@ -import os - -from transformers import AutoTokenizer -from utils import ChatPromptProcessor, Dialogue - -CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." -tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"]) - -samples = [ - ( - [ - Dialogue( - instruction="Who is the best player in the history of NBA?", - response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", - ), - Dialogue(instruction="continue this talk", response=""), - ], - 128, - "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", - ), - ( - [ - Dialogue( - instruction="Who is the best player in the history of NBA?", - response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", - ), - Dialogue(instruction="continue this talk", response=""), - ], - 200, - "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", - ), - ( - [ - Dialogue( - instruction="Who is the best player in the history of NBA?", - response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", - ), - Dialogue(instruction="continue this talk", response=""), - ], - 211, - "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n", - ), - ( - [ - Dialogue(instruction="Who is the best player in the history of NBA?", response=""), - ], - 128, - "Below is an instruction that describes a task. 
Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n", - ), -] - - -def test_chat_prompt_processor(): - processor = ChatPromptProcessor(tokenizer, CONTEXT, 256) - for history, max_new_tokens, result in samples: - prompt = processor.preprocess_prompt(history, max_new_tokens) - assert prompt == result - - -if __name__ == "__main__": - test_chat_prompt_processor() diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py deleted file mode 100644 index af018adf6..000000000 --- a/applications/Chat/inference/utils.py +++ /dev/null @@ -1,209 +0,0 @@ -import json -import re -from threading import Lock -from typing import Any, Callable, Generator, List, Optional - -import jieba -import torch -import torch.distributed as dist -import torch.nn as nn -from pydantic import BaseModel, Field - -try: - from transformers.generation_logits_process import ( - LogitsProcessorList, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - ) -except ImportError: - from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper - - -def prepare_logits_processor( - top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None -) -> LogitsProcessorList: - processor_list = LogitsProcessorList() - if temperature is not None and temperature != 1.0: - processor_list.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - processor_list.append(TopKLogitsWarper(top_k)) - if top_p is not None and top_p < 1.0: - processor_list.append(TopPLogitsWarper(top_p)) - return processor_list - - -def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: - if dist.is_initialized() and dist.get_world_size() > 1: - # consider DP - unfinished_sequences = unfinished_sequences.clone() - dist.all_reduce(unfinished_sequences) - return unfinished_sequences.max() == 0 - - -def sample_streamingly( - model: nn.Module, - input_ids: torch.Tensor, - max_generate_tokens: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs, -) -> Generator: - logits_processor = prepare_logits_processor(top_k, top_p, temperature) - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - - for _ in range(max_generate_tokens): - model_inputs = ( - prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} - ) - outputs = model(**model_inputs) - - next_token_logits = outputs["logits"][:, -1, :] - # pre-process distribution - next_token_logits = logits_processor(input_ids, next_token_logits) - # sample - probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - yield next_tokens - - # update generated 
ids, model inputs for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if update_model_kwargs_fn is not None: - model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - - # stop when each sentence is finished if early_stopping=True - if early_stopping and _is_sequence_finished(unfinished_sequences): - break - - -def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: - if "past_key_values" in outputs: - model_kwargs["past"] = outputs["past_key_values"] - else: - model_kwargs["past"] = None - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - return model_kwargs - - -class Dialogue(BaseModel): - instruction: str = Field(min_length=1, example="Count up from 1 to 500.") - response: str = Field(example="") - - -def _format_dialogue(instruction: str, response: str = ""): - return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}" - - -STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S)) - - -class ChatPromptProcessor: - SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt." - - def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): - self.tokenizer = tokenizer - self.context = context - self.max_len = max_len - self.censored_words = set([word.lower() for word in censored_words]) - # These will be initialized after the first call of preprocess_prompt() - self.context_len: Optional[int] = None - self.dialogue_placeholder_len: Optional[int] = None - - def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: - if self.context_len is None: - self.context_len = len(self.tokenizer(self.context)["input_ids"]) - if self.dialogue_placeholder_len is None: - self.dialogue_placeholder_len = len( - self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"] - ) - prompt = self.context - # the last dialogue must be in the prompt - last_dialogue = history.pop() - # the response of the last dialogue is empty - assert last_dialogue.response == "" - if ( - len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"]) - + max_new_tokens - + self.context_len - >= self.max_len - ): - # to avoid truncate placeholder, apply truncate to the original instruction - instruction_truncated = self.tokenizer( - last_dialogue.instruction, - add_special_tokens=False, - truncation=True, - max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len), - )["input_ids"] - instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() - prompt += _format_dialogue(instruction_truncated) - return prompt - - res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"]) - - rows = [] - for dialogue in history[::-1]: - text = _format_dialogue(dialogue.instruction, dialogue.response) - cur_len = 
len(self.tokenizer(text, add_special_tokens=False)["input_ids"]) - if res_len - cur_len < 0: - break - res_len -= cur_len - rows.insert(0, text) - prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction) - return prompt - - def postprocess_output(self, output: str) -> str: - output = STOP_PAT.sub("", output) - return output.strip() - - def has_censored_words(self, text: str) -> bool: - if len(self.censored_words) == 0: - return False - intersection = set(jieba.cut(text.lower())) & self.censored_words - return len(intersection) > 0 - - -class LockedIterator: - def __init__(self, it, lock: Lock) -> None: - self.lock = lock - self.it = iter(it) - - def __iter__(self): - return self - - def __next__(self): - with self.lock: - return next(self.it) - - -def load_json(path: str): - with open(path) as f: - return json.load(f) diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt deleted file mode 100644 index 93d48bcb6..000000000 --- a/applications/Chat/requirements-test.txt +++ /dev/null @@ -1,2 +0,0 @@ -pytest -colossalai==0.3.3 diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt deleted file mode 100644 index e56aaca0e..000000000 --- a/applications/Chat/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -transformers>=4.20.1 -tqdm -datasets -loralib -colossalai==0.3.3 -torch<2.0.0, >=1.12.1 -langchain -tokenizers -fastapi -sse_starlette -wandb -sentencepiece -gpustat -tensorboard diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh deleted file mode 100755 index 3fdb25181..000000000 --- a/applications/Chat/tests/test_benchmarks.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -set -xue - -echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies." - -if [[ $# -ne 0 && "$1" == "verbose" ]]; then - STRATEGIES=( - 'ddp' - 'colossalai_gemini' - 'colossalai_gemini_cpu' - 'colossalai_zero2' - 'colossalai_zero2_cpu' - 'colossalai_zero1' - 'colossalai_zero1_cpu' - ) -else - STRATEGIES=( - 'colossalai_zero2' - ) -fi - -BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) -BENCHMARKS_DIR=$BASE_DIR/benchmarks - -echo "[Test]: testing benchmarks ..." 
- -for strategy in ${STRATEGIES[@]}; do - torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \ - --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \ - --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \ - --train_batch_size 2 --experience_batch_size 4 -done diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py deleted file mode 100644 index 9c08aa36c..000000000 --- a/applications/Chat/tests/test_checkpoint.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -import tempfile -from contextlib import nullcontext - -import pytest -import torch -import torch.distributed as dist -from coati.models.gpt import GPTActor -from coati.models.utils import calc_action_log_probs -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy -from transformers.models.gpt2.configuration_gpt2 import GPT2Config - -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import rerun_if_address_is_in_use, spawn - -GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) - - -def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") - attention_mask = torch.ones_like(input_ids) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): - data = get_data(batch_size) - action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) - actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"] - action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1)) - loss = action_log_probs.sum() - strategy.backward(loss, actor, actor_optim) - strategy.optimizer_step(actor_optim) - - -def run_test_checkpoint(strategy_name: str, shard: bool): - if strategy_name == "ddp": - strategy = DDPStrategy() - elif strategy_name == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5) - elif strategy_name == "colossalai_zero2": - strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") - else: - raise ValueError(f"Unsupported strategy '{strategy_name}'") - - with strategy.model_init_context(): - actor = GPTActor(config=GPT_CONFIG).cuda() - actor_optim = HybridAdam(actor.parameters()) - actor, actor_optim = strategy.prepare((actor, actor_optim)) - - train_step(strategy, actor, actor_optim) - - ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext() - - with ctx as dirname: - rank0_dirname = [dirname] - dist.broadcast_object_list(rank0_dirname) - rank0_dirname = rank0_dirname[0] - - model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt") - strategy.save_model(actor, model_path) - optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt") - strategy.save_optimizer(actor_optim, optim_path) - dist.barrier() - - strategy.load_model(actor, model_path, strict=False) - strategy.load_optimizer(actor_optim, optim_path) - dist.barrier() - - train_step(strategy, actor, actor_optim) - - -def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool): - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) - run_test_checkpoint(strategy_name, shard) - - -@pytest.mark.dist 
-@pytest.mark.parametrize("world_size", [4]) -@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) -@pytest.mark.parametrize("shard", [False, True]) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size: int, strategy_name: str, shard: bool): - spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard) - - -if __name__ == "__main__": - test_checkpoint(2, "colossalai_gemini", shard=False) diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py deleted file mode 100644 index ec61bbb13..000000000 --- a/applications/Chat/tests/test_dataset.py +++ /dev/null @@ -1,241 +0,0 @@ -import json -import os -import tempfile -from typing import Optional - -import pytest -import torch -from coati.dataset.prompt_dataset import PromptDataset -from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset -from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer -from datasets import load_dataset -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer -from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer - -SFT_DATASET = [ - { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0, - }, - { - "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": "", - "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": 1, - }, - { - "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": "", - "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": 2, - }, -] - -PROMPT_DATASET = [ - { - "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. 
After that, I went for a walk and met some friends."', - "id": 0, - }, - {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1}, - {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2}, - {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3}, -] - - -def make_tokenizer(model: str): - if model == "gpt2": - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - tokenizer.pad_token = tokenizer.eos_token - elif model == "bloom": - tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") - tokenizer.pad_token = tokenizer.eos_token - elif model == "opt": - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - tokenizer.pad_token = tokenizer.eos_token - elif model == "llama": - tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.pad_token = tokenizer.unk_token - elif model == "chatglm": - tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) - else: - raise ValueError(f"Unsupported model '{model}'") - return tokenizer - - -def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str): - if model == "opt": - # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt. - assert input_ids_stripped[0] == tokenizer.eos_token_id - input_ids_stripped = input_ids_stripped[1:] - elif model == "llama": - assert input_ids_stripped[0] == tokenizer.bos_token_id - input_ids_stripped = input_ids_stripped[1:] - elif model == "chatglm": - assert input_ids_stripped[0] == tokenizer.bos_token_id - assert input_ids_stripped[-1] == tokenizer.eos_token_id - input_ids_stripped = input_ids_stripped[1:-1] - assert torch.all(input_ids_stripped != tokenizer.pad_token_id) - assert torch.all(input_ids_stripped != tokenizer.bos_token_id) - assert torch.all(input_ids_stripped != tokenizer.eos_token_id) - assert input_ids_stripped != tokenizer.sep_token_id - assert input_ids_stripped != tokenizer.cls_token_id - if model == "chatglm": - assert torch.all(input_ids_stripped != tokenizer.mask_token_id) - else: - assert input_ids_stripped != tokenizer.mask_token_id - - -@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize("max_length", [32, 1024]) -@pytest.mark.parametrize("max_datasets_size", [2]) -def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): - with tempfile.TemporaryDirectory() as tmp_dir: - dataset_name = "prompt_dataset.json" - with open(os.path.join(tmp_dir, dataset_name), "w") as f: - json.dump(PROMPT_DATASET, f) - tokenizer = make_tokenizer(model) - assert tokenizer.padding_side in ("left", "right") - prompt_dataset = PromptDataset( - data_path=os.path.join(tmp_dir, dataset_name), - tokenizer=tokenizer, - max_datasets_size=max_datasets_size, - max_length=max_length, - ) - assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) - for i in range(len(prompt_dataset)): - assert isinstance(prompt_dataset[i], dict) - assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"] - input_ids = prompt_dataset[i]["input_ids"] - attention_mask = prompt_dataset[i]["attention_mask"] - attention_mask = attention_mask.bool() - assert input_ids.shape == attention_mask.shape == torch.Size([max_length]) - assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id) - check_content(input_ids.masked_select(attention_mask), 
tokenizer, model) - - -@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize( - ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)] -) -@pytest.mark.parametrize("max_datasets_size", [32]) -@pytest.mark.parametrize("max_length", [32, 1024]) -def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): - data = load_dataset(dataset_path, data_dir=subset) - assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"]) - train_data = data["train"].select(range(max_datasets_size)) - test_data = data["test"].select(range(max_datasets_size)) - tokenizer = make_tokenizer(model) - assert tokenizer.padding_side in ("left", "right") - - if dataset_path == "Anthropic/hh-rlhf": - train_dataset = HhRlhfDataset(train_data, tokenizer, max_length) - test_dataset = HhRlhfDataset(test_data, tokenizer, max_length) - elif dataset_path == "Dahoas/rm-static": - train_dataset = RmStaticDataset(train_data, tokenizer, max_length) - test_dataset = RmStaticDataset(test_data, tokenizer, max_length) - else: - raise ValueError(f'Unsupported dataset "{dataset_path}"') - - assert len(train_dataset) == len(test_dataset) == max_datasets_size - for i in range(max_datasets_size): - chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] - assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) - c_mask = c_mask.to(torch.bool) - r_mask = r_mask.to(torch.bool) - if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: - check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) - assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) - else: - check_content(chosen_ids.masked_select(c_mask), tokenizer, model) - assert torch.all(c_mask) - if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: - check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) - assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) - else: - check_content(reject_ids.masked_select(r_mask), tokenizer, model) - assert torch.all(r_mask) - - chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] - assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) - c_mask = c_mask.to(torch.bool) - r_mask = r_mask.to(torch.bool) - if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: - check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model) - assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id) - else: - check_content(chosen_ids.masked_select(c_mask), tokenizer, model) - assert torch.all(c_mask) - if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id: - check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model) - assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id) - else: - check_content(reject_ids.masked_select(r_mask), tokenizer, model) - assert torch.all(r_mask) - - -@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) -@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) -@pytest.mark.parametrize("max_dataset_size", [2]) -@pytest.mark.parametrize("max_length", [32, 1024]) -def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int): - tokenizer = 
make_tokenizer(model) - if dataset_path == "yizhongw/self_instruct": - data = load_dataset(dataset_path, "super_natural_instructions") - train_data = data["train"].select(range(max_dataset_size)) - sft_dataset = SFTDataset(train_data, tokenizer, max_length) - else: - with tempfile.TemporaryDirectory() as tmp_dir: - dataset_name = "sft_dataset.json" - with open(os.path.join(tmp_dir, dataset_name), "w") as f: - json.dump(SFT_DATASET, f) - sft_dataset = SupervisedDataset( - tokenizer=tokenizer, - data_path=os.path.join(tmp_dir, dataset_name), - max_datasets_size=max_dataset_size, - max_length=max_length, - ) - assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) - - if isinstance(tokenizer, ChatGLMTokenizer): - for i in range(max_dataset_size): - assert isinstance(sft_dataset[i], dict) - assert list(sft_dataset[i].keys()) == ["input_ids", "labels"] - input_ids = sft_dataset[i]["input_ids"] - labels = sft_dataset[i]["labels"] - assert input_ids.shape == labels.shape == torch.Size([max_length]) - - ignore_mask = labels == IGNORE_INDEX - assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id - check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model) - return - - for i in range(max_dataset_size): - assert isinstance(sft_dataset[i], dict) - assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] - input_ids = sft_dataset[i]["input_ids"] - labels = sft_dataset[i]["labels"] - attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) - assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length]) - if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: - check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) - assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) - else: - check_content(input_ids.masked_select(attention_mask), tokenizer, model) - assert torch.all(attention_mask) - ignore_mask = labels == IGNORE_INDEX - prompt_mask = torch.logical_and(ignore_mask, attention_mask) - check_content(input_ids.masked_select(prompt_mask), tokenizer, model) - assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id) - - -if __name__ == "__main__": - test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) - - test_reward_dataset( - model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256 - ) - - test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py deleted file mode 100644 index a95912598..000000000 --- a/applications/Chat/tests/test_experience.py +++ /dev/null @@ -1,130 +0,0 @@ -import copy -import os - -import pytest -import torch -import torch.distributed as dist -from coati.experience_buffer import NaiveExperienceBuffer -from coati.experience_maker import NaiveExperienceMaker -from coati.models.base import RewardModel -from coati.models.gpt import GPTActor, GPTCritic -from coati.trainer.ppo import _set_default_generate_kwargs -from coati.trainer.strategies import DDPStrategy, GeminiStrategy -from coati.trainer.strategies.colossalai import LowLevelZeroStrategy -from transformers.models.gpt2.configuration_gpt2 import GPT2Config - -from colossalai.testing import rerun_if_address_is_in_use, spawn - -GPT_CONFIG = 
GPT2Config(n_embd=128, n_layer=4, n_head=4) - - -def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") - attention_mask = torch.ones_like(input_ids) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - -def gather_and_equal(tensor: torch.Tensor) -> bool: - world_size = dist.get_world_size() - outputs = [torch.empty_like(tensor) for _ in range(world_size)] - dist.all_gather(outputs, tensor.contiguous()) - for t in outputs[1:]: - if not torch.equal(outputs[0], t): - return False - return True - - -def make_and_consume_experience(strategy): - EXPERIENCE_BATCH_SIZE = 4 - SAMPLE_BATCH_SIZE = 2 - - if strategy == "ddp": - strategy = DDPStrategy() - elif strategy == "colossalai-zero2": - strategy = LowLevelZeroStrategy() - elif strategy == "colossalai-gemini": - strategy = GeminiStrategy(placement_policy="static") - else: - raise ValueError(f'Unsupported strategy "{strategy}"') - - with strategy.model_init_context(): - actor = GPTActor(config=GPT_CONFIG).cuda() - critic = GPTCritic(config=GPT_CONFIG).cuda() - - initial_model = GPTActor(config=GPT_CONFIG).cuda() - reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda() - - actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model) - - class MockTokenizer: - def __init__(self): - self.padding_side = "left" - self.eos_token_id = 0 - self.pad_token_id = 0 - - tokenizer = MockTokenizer() - experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer) - data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False) - - generate_kwargs = dict(do_sample=True, max_length=16) - generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) - - # experience of all ranks should be the same - for _ in range(2): - data = get_data(EXPERIENCE_BATCH_SIZE) - assert gather_and_equal(data["input_ids"]) - assert gather_and_equal(data["attention_mask"]) - experience = experience_maker.make_experience(**data, do_sample=True, max_length=16) - assert gather_and_equal(experience.sequences) - assert gather_and_equal(experience.action_log_probs) - assert gather_and_equal(experience.values) - assert gather_and_equal(experience.reward) - assert gather_and_equal(experience.advantages) - assert gather_and_equal(experience.action_mask) - assert gather_and_equal(experience.attention_mask) - data_buffer.append(experience) - - # data buffer's data should be the same - buffer_size = torch.tensor([len(data_buffer)], device="cuda") - assert gather_and_equal(buffer_size) - for item in data_buffer.items: - assert gather_and_equal(item.sequences) - assert gather_and_equal(item.action_log_probs) - assert gather_and_equal(item.values) - assert gather_and_equal(item.reward) - assert gather_and_equal(item.advantages) - assert gather_and_equal(item.action_mask) - assert gather_and_equal(item.attention_mask) - - # dataloader of each rank should have the same size and different batch - dataloader = strategy.setup_dataloader(data_buffer) - dataloader_size = torch.tensor([len(dataloader)], device="cuda") - assert gather_and_equal(dataloader_size) - for experience in dataloader: - assert not gather_and_equal(experience.sequences) - assert not gather_and_equal(experience.action_log_probs) - assert not gather_and_equal(experience.values) - assert not gather_and_equal(experience.reward) - assert not gather_and_equal(experience.advantages) - # action mask and attention mask may be same 
- - -def run_dist(rank, world_size, port, strategy): - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) - make_and_consume_experience(strategy) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"]) -@rerun_if_address_is_in_use() -def test_experience(world_size, strategy): - spawn(run_dist, world_size, strategy=strategy) - - -if __name__ == "__main__": - test_experience(2, "colossalai-zero2") diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh deleted file mode 100755 index 849db06e5..000000000 --- a/applications/Chat/tests/test_inference.sh +++ /dev/null @@ -1,11 +0,0 @@ -set -xue - -BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) -EXAMPLES_DIR=$BASE_DIR/examples - -echo "[Test]: testing inference ..." - -# HACK: skip llama due to oom -for model in 'gpt2' 'bloom' 'opt'; do - python $EXAMPLES_DIR/inference.py --model $model -done diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py deleted file mode 100644 index b2c22ac6a..000000000 --- a/applications/Chat/tests/test_models.py +++ /dev/null @@ -1,245 +0,0 @@ -import copy -from typing import Any, Callable, Dict, Tuple - -import pytest -import torch -import torch.nn as nn -from coati.models.base import Actor, Critic, RewardModel, get_base_model -from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic -from coati.models.chatglm import ChatGLMActor -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer -from coati.models.generation import generate -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor -from coati.models.lora import LoraLinear, convert_to_lora_module -from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss -from coati.models.opt import OPTRM, OPTActor, OPTCritic -from coati.models.utils import calc_action_log_probs, masked_mean - - -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seq_len", [32]) -@pytest.mark.parametrize( - "actor_maker", - [ - lambda: BLOOMActor(), - lambda: GPTActor(), - # HACK: skip llama due to long execution time - # lambda: LlamaActor(), - lambda: OPTActor(), - ], -) -@pytest.mark.parametrize( - "generate_kwargs", - [ - { - "max_length": 64, - "use_cache": True, - "do_sample": True, - "temperature": 1.0, - "top_k": 50, - } - ], -) -def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): - class MockTokenizer: - def __init__(self): - self.padding_side = "left" - self.eos_token_id = 0 - self.pad_token_id = 0 - - actor = actor_maker() - input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() - tokenizer = MockTokenizer() - sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs) - assert sequences.shape == (batch_size, generate_kwargs["max_length"]) - - -def test_utils(): - fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))} - fn_output = masked_mean(dim=0, **fn_input) - assert fn_output.dim() == 0 - assert torch.allclose(fn_output, torch.tensor(1.0)) - - batch_size = 4 - seq_len = 32 - num_labels = 10 - num_actions = 2 - fn_input = { - "logits": torch.randn((batch_size, seq_len, num_labels)), - "sequences": torch.randint(0, num_labels, 
(batch_size, seq_len)), - "num_actions": num_actions, - } - fn_output = calc_action_log_probs(**fn_input) - assert fn_output.shape == (batch_size, num_actions) - - -@pytest.mark.parametrize("lora_rank", [4]) -@pytest.mark.parametrize("num_dim", [32]) -@pytest.mark.parametrize("num_layers", [4]) -def test_lora(lora_rank: int, num_dim: int, num_layers: int): - model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)]) - lora_model = convert_to_lora_module(model, lora_rank) - assert isinstance(lora_model, nn.ModuleList) - for i in range(num_layers): - assert isinstance(lora_model[i], LoraLinear) - assert lora_model[i].lora_A.shape == (lora_rank, num_dim) - assert lora_model[i].lora_B.shape == (num_dim, lora_rank) - - old_model = copy.deepcopy(lora_model) - for i in range(num_layers): - assert isinstance(lora_model[i], LoraLinear) - assert torch.allclose(old_model[i].weight, lora_model[i].weight) - assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A) - optimizer = torch.optim.Adam(lora_model.parameters()) - x = torch.randn(8, num_dim) - for i in range(num_layers): - x = lora_model[i](x) - loss = x.sum() - loss.backward() - optimizer.step() - for i in range(num_layers): - assert isinstance(lora_model[i], LoraLinear) - assert torch.allclose(old_model[i].weight, lora_model[i].weight) - assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert not torch.allclose( - old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A - ) - - -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seq_len", [128]) -@pytest.mark.parametrize( - "models_maker", - [ - lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), - lambda: (GPTActor(), GPTCritic(), GPTRM()), - # HACK: skip llama due to long execution time - # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), - lambda: (OPTActor(), OPTCritic(), OPTRM()), - lambda: (ChatGLMActor(), None, None), - ], -) -@torch.no_grad() -def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): - actor_input = { - "input_ids": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), - } - critic_input = { - "sequences": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), - } - rm_input = { - "sequences": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), - } - - actor, critic, rm = models_maker() - if isinstance(actor, ChatGLMActor): - actor = actor.float() - tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) - chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1) - actor_input = { - "input_ids": torch.cat( - ( - torch.randint(0, 100, (batch_size, seq_len // 2)), - chatglm_special_token, - torch.randint(0, 100, (batch_size, seq_len // 2 - 2)), - ), - dim=1, - ), - "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)), - } - assert isinstance(actor, Actor) - get_base_model(actor) - actor_output = actor(**actor_input) - assert actor_output.logits.shape[:2] == (batch_size, seq_len) - - if critic: - assert isinstance(critic, Critic) - get_base_model(critic) - critic_output = critic(**critic_input) - assert critic_output.shape == (batch_size,) - - if rm: - 
assert isinstance(rm, RewardModel) - get_base_model(rm) - rm_output = rm(**rm_input) - assert rm_output.shape == (batch_size,) - - -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("seq_len", [128]) -@pytest.mark.parametrize("num_labels", [100]) -def test_loss(batch_size: int, seq_len: int, num_labels: int): - loss = GPTLMLoss() - loss_input = { - "logits": torch.randn(batch_size, seq_len, num_labels), - "labels": torch.randint(0, num_labels, (batch_size, seq_len)), - } - loss(**loss_input) - - loss = PolicyLoss() - loss_input = { - "log_probs": torch.randn( - batch_size, - ), - "old_log_probs": torch.randn( - batch_size, - ), - "advantages": torch.randn( - batch_size, - ), - } - loss(**loss_input) - - loss = ValueLoss() - loss_input = { - "values": torch.randn( - batch_size, - ), - "old_values": torch.randn( - batch_size, - ), - "reward": torch.randn( - batch_size, - ), - } - loss(**loss_input) - - loss = LogSigLoss() - loss_input = { - "chosen_reward": torch.randn( - batch_size, - ), - "reject_reward": torch.randn( - batch_size, - ), - } - loss(**loss_input) - - loss = LogExpLoss() - loss_input = { - "chosen_reward": torch.randn( - batch_size, - ), - "reject_reward": torch.randn( - batch_size, - ), - } - loss(**loss_input) - - -if __name__ == "__main__": - generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50) - test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs) - - test_utils() - - test_lora(lora_rank=2, num_dim=8, num_layers=2) - - test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) - - test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh deleted file mode 100755 index 68fca7fbf..000000000 --- a/applications/Chat/tests/test_train.sh +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env bash - -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES 4 - -set -xu - -if [ -z "$SFT_DATASET" ]; then - echo "Please set \$SFT_DATASET to the path to sft dataset." - exit 1 -fi - -if [ -z "$PROMPT_DATASET" ]; then - echo "Please set \$PROMPT_DATASET to the path to prompts csv." - exit 1 -fi - -if [ -z "$PRETRAIN_DATASET" ]; then - echo "Please set \$PRETRAIN_DATASET to the path to alpaca data." 
- exit 1 -fi - -NUM_RETRY=3 -BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) -EXAMPLES_DIR=$BASE_DIR/examples -MODELS_DIR=$BASE_DIR/examples/models_config -MODELS=('gpt2' 'bloom' 'opt' 'llama') -STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2') - - -export OMP_NUM_THREADS=8 - -# install requirements -pip install -r $EXAMPLES_DIR/requirements.txt - -python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only - -get_pretrain() { - local model=$1 - if [[ $model == "gpt2" ]]; then - echo "gpt2" - elif [[ $model == "bloom" ]]; then - echo "bigscience/bloom-560m" - elif [[ $model == "opt" ]]; then - echo "facebook/opt-350m" - else - echo "Unknown model $model" - exit 1 - fi -} - -random_choice() { - local arr=("$@") - local len=${#arr[@]} - local idx=$((RANDOM % len)) - echo ${arr[$idx]} -} - -echo "[Test]: testing sft ..." - -# FIXME: This is a hack to skip tests that are not working -# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation -# - llama-*: These tests can be passed locally, skipped for long execution time -# - *-gemini: Gemini plugin does not support `from_pretrained` yet -SKIPPED_TESTS=( - "gpt2-ddp" - "llama-ddp" - "llama-colossalai_gemini" - "llama-colossalai_zero2" -) - -GRAD_CKPTS=('' '--grad_checkpoint') -for lora_rank in '0'; do - for model in ${MODELS[@]}; do - strategies=($(shuf -e "${STRATEGIES[@]}")) - for strategy in ${strategies[@]}; do - if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then - echo "[Test]: Skipped $model-$strategy-$lora_rank" - continue - elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then - echo "[Test]: Skipped $model-$strategy" - continue - fi - pretrain=$(get_pretrain $model) - pretrain_model="" - if [[ $lora_rank -gt 0 ]]; then - pretrain_model="--pretrain $pretrain" - fi - grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") - for i in $(seq $NUM_RETRY); do - echo "[Test]: $model-$strategy-$lora_rank, attempt $i" - torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \ - $pretrain_model --tokenizer $MODELS_DIR/$model \ - --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \ - --dataset $SFT_DATASET --max_datasets_size 8 \ - --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \ - --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} - passed=$? - if [ $passed -eq 0 ]; then - break - fi - done - if [ $passed -ne 0 ]; then - echo "[Test]: Failed $model-$strategy-$lora_rank" - exit 1 - fi - done - done -done - -echo "[Test]: testing reward model ..." 
- -# FIXME: This is a hack to skip tests that are not working -# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation -# - llama-*: These tests can be passed locally, skipped for long execution time -# - *-gemini: Gemini plugin does not support `from_pretrained` yet -SKIPPED_TESTS=( - "gpt2-ddp" - "llama-ddp" - "llama-colossalai_gemini" - "llama-colossalai_zero2" -) - -LOSS_FNS=('log_sig' 'log_exp') -DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static') -for lora_rank in '0'; do - for model in ${MODELS[@]}; do - strategies=($(shuf -e "${STRATEGIES[@]}")) - for strategy in ${strategies[@]}; do - if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then - echo "[Test]: Skipped $model-$strategy-$lora_rank" - continue - elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then - echo "[Test]: Skipped $model-$strategy" - continue - fi - pretrain=$(get_pretrain $model) - pretrain_model="" - if [[ $lora_rank -gt 0 ]]; then - pretrain_model="--pretrain $pretrain" - fi - loss_fn=$(random_choice "${LOSS_FNS[@]}") - dataset=$(random_choice "${DATASETS[@]}") - subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi) - for i in $(seq $NUM_RETRY); do - echo "[Test]: $model-$strategy-$lora_rank, attempt $i" - torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \ - $pretrain_model --tokenizer $MODELS_DIR/$model \ - --dataset $dataset --subset $subset --max_datasets_size 8 \ - --model $model --strategy $strategy --lora_rank $lora_rank \ - --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \ - --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt - passed=$? - if [ $passed -eq 0 ]; then - break - fi - done - if [ $passed -ne 0 ]; then - echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank" - exit 1 - fi - done - done -done - -echo "[Test]: testing RLHF ..." 
- -# FIXME: This is a hack to skip tests that are not working -# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation -# - llama-*: These tests can be passed locally, skipped for long execution time -# - *-gemini: Gemini plugin does not support `from_pretrained` yet -SKIPPED_TESTS=( - "gpt2-ddp" - "llama-ddp" - "llama-colossalai_gemini" - "llama-colossalai_zero2" -) - -for model in ${MODELS[@]}; do - for lora_rank in '0'; do - strategies=($(shuf -e "${STRATEGIES[@]}")) - for strategy in ${strategies[@]}; do - if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then - echo "[Test]: Skipped $model-$strategy-$lora_rank" - continue - elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then - echo "[Test]: Skipped $model-$strategy" - continue - fi - rm_pretrain=$(get_pretrain $model) - rm_pretrain_model="" - if [[ $lora_rank -gt 0 ]]; then - rm_pretrain_model="--rm_pretrain $rm_pretrain" - fi - for i in $(seq $NUM_RETRY); do - echo "[Test]: $model-$strategy-$lora_rank, attempt $i" - torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \ - --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \ - --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \ - --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \ - --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \ - --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \ - $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \ - --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts - passed=$? - if [ $passed -eq 0 ]; then - break - fi - done - if [ $passed -ne 0 ]; then - echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank" - exit 1 - fi - done - rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} - rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt - done -done -rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts diff --git a/applications/Chat/.gitignore b/applications/ColossalChat/.gitignore old mode 100644 new mode 100755 similarity index 87% rename from applications/Chat/.gitignore rename to applications/ColossalChat/.gitignore index 5fa068105..33950adc0 --- a/applications/Chat/.gitignore +++ b/applications/ColossalChat/.gitignore @@ -143,6 +143,17 @@ docs/.build *.pt # wandb log -example/wandb/ +examples/wandb/ +examples/logs/ +examples/output/ examples/awesome-chatgpt-prompts/ +temp/ + +# ColossalChat +applications/ColossalChat/logs +applications/ColossalChat/models +applications/ColossalChat/sft_data +applications/ColossalChat/prompt_data +applications/ColossalChat/preference_data +applications/ColossalChat/temp diff --git a/applications/Chat/LICENSE b/applications/ColossalChat/LICENSE old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/LICENSE rename to applications/ColossalChat/LICENSE diff --git a/applications/Chat/README.md b/applications/ColossalChat/README.md old mode 100644 new mode 100755 similarity index 59% rename from applications/Chat/README.md rename to applications/ColossalChat/README.md index 349c26aad..769f0b3d0 --- a/applications/Chat/README.md +++ b/applications/ColossalChat/README.md @@ -13,10 +13,10 @@ - [Install the environment](#install-the-environment) - [Install the Transformers](#install-the-transformers) - [How to use?](#how-to-use) - - [Supervised datasets 
collection](#supervised-datasets-collection) - - [RLHF Training Stage1 - Supervised instructs tuning](#RLHF-training-stage1---supervised-instructs-tuning) - - [RLHF Training Stage2 - Training reward model](#RLHF-training-stage2---training-reward-model) - - [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#RLHF-training-stage3---training-model-with-reinforcement-learning-by-human-feedback) + - [Supervised datasets collection](#step-1-data-collection) + - [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning) + - [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model) + - [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization) - [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training) - [Coati7B examples](#coati7b-examples) - [Generation](#generation) @@ -36,7 +36,7 @@ --- -## What is ColossalChat and Coati ? +## What Is ColossalChat And Coati ? [ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project. @@ -91,107 +91,191 @@ More details can be found in the latest news. ## Install -### Install the environment +### Install the Environment ```bash -conda create -n coati -conda activate coati +# Create new environment +conda create -n colossal-chat python=3.10.9 (>=3.8.7) +conda activate colossal-chat + +# Install flash-attention +git clone -b v2.0.5 https://github.com/Dao-AILab/flash-attention.git +cd $FLASH_ATTENTION_ROOT/ +pip install . +cd $FLASH_ATTENTION_ROOT/csrc/xentropy +pip install . +cd $FLASH_ATTENTION_ROOT/csrc/layer_norm +pip install . +cd $FLASH_ATTENTION_ROOT/csrc/rotary +pip install . + +# Clone Colossalai git clone https://github.com/hpcaitech/ColossalAI.git -cd ColossalAI/applications/Chat + +# Install ColossalAI +cd $COLOSSAL_AI_ROOT +BUILD_EXT=1 pip install . + +# Install ColossalChat +cd $COLOSSAL_AI_ROOT/applications/Chat pip install . ``` -### Install the Transformers +## How To Use? -```bash -pip install transformers==4.30.2 -``` +### RLHF Training Stage1 - Supervised Instructs Tuning -## How to use? +Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat. More details can be found in [example guideline](./examples/README.md). -### Supervised datasets collection - -We collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo -[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md). - -Here is how we collected the data - -

                  - -

                  - -### RLHF Training Stage1 - Supervised instructs tuning - -Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model. - -You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning. -[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg) - -**Note**: the supervised dataset follows the following format, +#### Step 1: Data Collection +The first step in Stage 1 is to collect a dataset of human demonstrations of the following format. ```json [ - { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 + {"messages": + [ + { + "from": "human", + "content": "what are some pranks with a pen i can do?" + }, + { + "from": "assistant", + "content": "Are you looking for practical joke ideas?" + }, + ... + ] }, ... ] ``` -### RLHF Training Stage2 - Training reward model +#### Step 2: Preprocessing +Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization. -Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model +In this code, we provide a flexible way for users to set the conversation template for formatting chat data using Huggingface's newest feature--- chat template. Please follow the [example guideline](./examples/README.md) on how to format and tokenize data. -You can run the `examples/train_rm.sh` to start a reward model training. -[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo) +#### Step 3: Training +Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. More details can be found in [example guideline](./examples/README.md). -### RLHF Training Stage3 - Training model with reinforcement learning by human feedback +### RLHF Training Stage2 - Training Reward Model -Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process: +Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model. + +#### Step 1: Data Collection +Below shows the preference dataset format used in training the reward model. + +```json +[ + {"context": [ + { + "from": "human", + "content": "Introduce butterflies species in Oregon." + } + ] + "chosen": [ + { + "from": "assistant", + "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..." + }, + ... + ], + "rejected": [ + { + "from": "assistant", + "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..." + }, + ... + ] + }, + ... 
+]
+```
+
+#### Step 2: Preprocessing
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
+
+#### Step 3: Training
+You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. More details can be found in the [example guideline](./examples/README.md).
+
+### RLHF Training Stage3 - Proximal Policy Optimization
+
+In stage 3 we use the Proximal Policy Optimization (PPO) reinforcement learning algorithm, which is the most complex part of the training process:

                  -You can run the `examples/train_prompts.sh` to start training PPO with human feedback. -[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g) +#### Step 1: Data Collection +PPO uses two kind of training data--- the prompt data and the sft data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format. -**Note**: the required datasets follow the following format, +```json +[ + {"messages": + [ + { + "from": "human", + "content": "what are some pranks with a pen i can do?" + } + ... + ] + }, +] +``` -- `pretrain dataset` +#### Step 2: Data Preprocessing +To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh) - ```json - [ - { - "instruction": "Provide a list of the top 10 most popular mobile games in Asia", - "input": "", - "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": 0 - }, - ... - ] - ``` +#### Step 3: Training +You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. More detais can be found in [example guideline](./examples/README.md). -- `prompt dataset` +```bash +--pretrain $PRETRAINED_MODEL_PATH \ +--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architectual +--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ +--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path +--prompt_dataset ${prompt_dataset[@]} \ # List of string, the prompt dataset +--ptx_dataset ${ptx_dataset[@]} \ # List of string, the SFT data used in the SFT stage +--ptx_batch_size 1 \ # batch size for calculate ptx loss +--ptx_coef 0.0 \ # none-zero if ptx loss is enable +--num_episodes 2000 \ # number of episodes to train +--num_collect_steps 1 \ +--num_update_steps 1 \ +--experience_batch_size 8 \ +--train_batch_size 4 \ +--accumulation_steps 2 +``` - ```json - [ - { - "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": 0 - }, - { - "instruction": "Write a descriptive paragraph about a memorable vacation you went on", - "id": 1 - }, - ... - ] - ``` +Each episode has two phases, the collect phase and the update phase. During the collect phase, we will collect experiences (answers generated by actor), store those in ExperienceBuffer. Then data in ExperienceBuffer is used during the update phase to update parameter of actor and critic. -For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples). 
+
+- Without tensor parallelism,
+```
+experience buffer size
+= num_process * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_process
+```
+
+- With tensor parallelism,
+```
+num_tp_group = num_process / tp
+experience buffer size
+= num_tp_group * num_collect_steps * experience_batch_size
+= train_batch_size * accumulation_steps * num_tp_group
+```
+
+## Alternative Option For RLHF: Direct Preference Optimization
+
+For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. As detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
+
+### DPO Training Stage1 - Supervised Instructs Tuning
+
+Please refer to the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning) in the PPO part.
+
+### DPO Training Stage2 - DPO Training
+#### Step 1: Data Collection & Preparation
+For DPO training, you only need the preference dataset. Please follow the instructions in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
+
+#### Step 2: Training
+You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More details can be found in the [example guideline](./examples/README.md).

 ### Inference Quantization and Serving - After Training

@@ -301,91 +385,60 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/Instructi

 We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.

+- Option 1: Save the model weights, model config and generation config (Note: the tokenizer will not be saved), which can be loaded using HF's from_pretrained method.
 ```python
-from coati.models.llama import LlamaLM
-from coati.trainer import SFTTrainer
+# if using lora, you can choose to merge lora weights before saving
+if args.lora_rank > 0 and args.merge_lora_weights:
+    from coati.models.lora import LORA_MANAGER

-model = LlamaLM(pretrained=args.pretrain)
-tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+    # NOTE: set model to eval to merge LoRA weights
+    LORA_MANAGER.merge_weights = True
+    model.eval()
+# save model checkpoint after fitting on only rank0
+booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)

-(model, optim) = strategy.prepare((model, optim))
-trainer = SFTTrainer(model=model,
-                     strategy=strategy,
-                     optim=optim,
-                     train_dataloader=train_dataloader,
-                     eval_dataloader=eval_dataloader,
-                     batch_size=args.batch_size,
-                     max_epochs=args.max_epochs,
-                     accumulation_steps=args.accumulation_steps
-                     )
-
-trainer.fit()
-# this saves in pytorch format
-strategy.save_model(model, args.save_path, only_rank0=True)
-
-# this saves in HF format
-strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer)
 ```
+- Option 2: Save the model weights, model config, generation config, as well as the optimizer, learning rate scheduler and running states (Note: the tokenizer will not be saved), which are needed for resuming training.
+```python +from coati.utils import save_checkpoint +# save model checkpoint after fitting on only rank0 +save_checkpoint( + save_dir=actor_save_dir, + booster=actor_booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + epoch=0, + step=step, + batch_size=train_batch_size, + coordinator=coordinator, + ) +``` +To load the saved checkpoint +```python +from coati.utils import load_checkpoint +start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) +```
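+
+After loading, the returned values can be fed back into the training loop to resume where the run left off. The following is only a minimal sketch, assuming the dataloader was built with the `StatefulDistributedSampler` used elsewhere in this PR; the loop variables are illustrative, not part of the coati API:
+
+```python
+# Skip the samples that were already consumed before the checkpoint was written
+train_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+for epoch in range(start_epoch, args.max_epochs):
+    for step, batch in enumerate(train_dataloader, start=start_step):
+        ...  # one training step
+```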
How to train with limited resources

-Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.
+Here are some suggestions that can help you train a 7B model on a single or multiple consumer-grade GPUs.

-If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model.
+`batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters for successfully training the model. To maintain a decent batch size for gradient calculation, consider increasing `accumulation_steps` and reducing `batch_size` on each rank.

-```bash
-// [INFO]: MAX GPU MEMORY ALLOCATED: 19148.9345703125 MB
-torchrun --standalone --nproc_per_node=1 train_sft.py \
-    --pretrain "/path/to/LLaMa-7B/" \
-    --model 'llama' \
-    --strategy ddp \
-    --save_path /path/to/Coati-7B \
-    --dataset /path/to/data.json \
-    --batch_size 1 \
-    --accumulation_steps 8 \
-    --lr 2e-5 \
-    --max_datasets_size 512 \
-    --max_epochs 1 \
-    --lora_rank 16 \
-    --grad_checkpoint
-```
+If you only have a single 24G GPU, using LoRA together with the `zero2_cpu` plugin will generally be sufficient.

-`colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script.
-
-```bash
-torchrun --standalone --nproc_per_node=1 train_sft.py \
-    --pretrain "/path/to/LLaMa-7B/" \
-    --model 'llama' \
-    --strategy colossalai_gemini \
-    --save_path /path/to/Coati-7B \
-    --dataset /path/to/data.json \
-    --batch_size 1 \
-    --accumulation_steps 8 \
-    --lr 2e-5 \
-    --max_datasets_size 512 \
-    --max_epochs 1 \
-    --grad_checkpoint
-```
-
-If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows.
-
-```bash
-torchrun --standalone --nproc_per_node=4 train_sft.py \
-    --pretrain "/path/to/LLaMa-7B/" \
-    --model 'llama' \
-    --strategy colossalai_zero2_cpu \
-    --save_path /path/to/Coati-7B \
-    --dataset /path/to/data.json \
-    --batch_size 1 \
-    --accumulation_steps 8 \
-    --lr 2e-5 \
-    --max_datasets_size 512 \
-    --max_epochs 1 \
-    --grad_checkpoint
-```
+The `gemini` and `gemini_auto` plugins can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory, but they do not support gradient accumulation.
+If you have multiple GPUs, each with very limited VRAM (say 8GB), you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs you have.
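+
+For example, a multi-GPU, low-VRAM launch might look like the sketch below. This is only an illustration: it assumes the training script you run exposes `--plugin`, `--tp`, `--batch_size` and `--grad_checkpoint` options like the scripts referenced in this README; adjust the script name and arguments to your setup.
+
+```bash
+# Illustrative only: 8 GPUs with ~8GB VRAM each, sharding the model
+# with tensor parallelism across all 8 devices.
+torchrun --standalone --nproc_per_node=8 train_sft.py \
+    --plugin 3d \
+    --tp 8 \
+    --batch_size 1 \
+    --grad_checkpoint
+```
+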
                  ## The Plan @@ -396,6 +449,8 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ - [x] support inference - [x] support llama from [facebook](https://github.com/facebookresearch/llama) - [x] implement PPO-ptx fine-tuning +- [x] support flash-attention +- [x] implement DPO fine-tuning - [ ] integrate with Ray - [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL), - [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain) @@ -467,6 +522,7 @@ Coati is developed by ColossalAI Team: - [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT. - [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development. - [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements. +- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored version with updated acceleration framework, LoRA, DPO and PPO. The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project. - [Zangwei Zheng](https://github.com/zhengzangw) diff --git a/applications/ColossalChat/benchmarks/Opt.json b/applications/ColossalChat/benchmarks/Opt.json new file mode 100644 index 000000000..6d47666bb --- /dev/null +++ b/applications/ColossalChat/benchmarks/Opt.json @@ -0,0 +1,17 @@ +{ + "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "human_line_start": [ + 2 + ], + "human_line_end": [ + 2 + ], + "assistant_line_start": [ + 2 + ], + "assistant_line_end": [ + 2 + ], + "end_of_system_line_position": 0 +} diff --git a/applications/ColossalChat/benchmarks/README.md b/applications/ColossalChat/benchmarks/README.md new file mode 100755 index 000000000..08c5e0e6c --- /dev/null +++ b/applications/ColossalChat/benchmarks/README.md @@ -0,0 +1,37 @@ +# Benchmarks + +## Benchmark OPT with LoRA on dummy prompt data + +We provide various OPT models (string in parentheses is the corresponding model name used in this script): + +- OPT-125M (125m) +- OPT-350M (350m) +- OPT-700M (700m) +- OPT-1.3B (1.3b) +- OPT-2.7B (2.7b) +- OPT-3.5B (3.5b) +- OPT-5.5B (5.5b) +- OPT-6.7B (6.7b) +- OPT-10B (10b) +- OPT-13B (13b) + +We also provide various training strategies: + +- gemini: ColossalAI GeminiPlugin with `placement_policy="cuda"`, like zero3 +- gemini_auto: ColossalAI GeminiPlugin with `placement_policy="cpu"`, like zero3-offload +- zero2: ColossalAI zero2 +- zero2_cpu: ColossalAI zero2-offload +- 3d: ColossalAI HybridParallelPlugin with TP, DP support + +## How to Run +```bash +cd ../tests +# Prepare data for benchmark +SFT_DATASET=/path/to/sft/data/ \ +PROMPT_DATASET=/path/to/prompt/data/ \ +PRETRAIN_DATASET=/path/to/ptx/data/ \ +PREFERENCE_DATASET=/path/to/preference/data \ +./test_data_preparation.sh +# Start benchmark +./benchmark_ppo.sh +``` diff --git a/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt b/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt new file mode 100644 index 000000000..049285552 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_memory_consumption.txt @@ -0,0 +1,4 @@ +Model=Opt-125m; lora_rank=0; plugin=zero2 +Max CUDA memory usage: 26123.16 MB +Model=Opt-125m; lora_rank=0; plugin=zero2 +Max CUDA memory usage: 26123.91 MB diff --git a/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt b/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt new file mode 100644 index 000000000..b2a1ff1d7 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_performance_summarization.txt @@ -0,0 +1,16 @@ +facebook/opt-125m; 0; zero2 +Performance summary: +Generate 768 samples, throughput: 188.48 samples/s, TFLOPS per GPU: 361.23 +Train 768 samples, throughput: 448.38 samples/s, TFLOPS per GPU: 82.84 +Overall throughput: 118.42 samples/s +Overall time per sample: 0.01 s +Make experience time per sample: 0.01 s, 62.83% +Learn time per sample: 0.00 s, 26.41% +facebook/opt-125m; 0; zero2 +Performance summary: +Generate 768 samples, throughput: 26.32 samples/s, TFLOPS per GPU: 50.45 +Train 768 samples, throughput: 71.15 samples/s, TFLOPS per GPU: 13.14 +Overall throughput: 18.86 samples/s +Overall time per sample: 0.05 s +Make experience time per sample: 0.04 s, 71.66% +Learn time per sample: 0.01 s, 26.51% diff --git a/applications/ColossalChat/benchmarks/benchmark_ppo.py b/applications/ColossalChat/benchmarks/benchmark_ppo.py new file mode 100644 index 000000000..e1b7a313f --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_ppo.py @@ -0,0 +1,523 @@ +""" +For becnhmarking ppo. 
Mudified from examples/training_scripts/train_ppo.py +""" + +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from coati.dataset import ( + DataCollatorForPromptDataset, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_conversation_template, + setup_distributed_dataloader, +) +from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout +from coati.trainer import PPOTrainer +from coati.trainer.callbacks import PerformanceEvaluator +from coati.trainer.utils import is_rank_0 +from coati.utils import load_checkpoint, replace_with_flash_attention +from transformers import AutoTokenizer, OPTForCausalLM +from transformers.models.opt.configuration_opt import OPTConfig + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def get_model_numel(model: torch.nn.Module, plugin: str, tp: int) -> int: + numel = sum(p.numel() for p in model.parameters()) + if plugin == "3d" and tp > 1: + numel *= dist.get_world_size() + return numel + + +def get_gpt_config(model_name: str) -> OPTConfig: + model_map = { + "125m": OPTConfig.from_pretrained("facebook/opt-125m"), + "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"), + "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"), + "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"), + "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + "13b": OPTConfig.from_pretrained("facebook/opt-13b"), + } + try: + return model_map[model_name] + except KeyError: + raise ValueError(f'Unknown model "{model_name}"') + + +def benchmark_train(args): + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + init_ctx = LazyInitContext(default_device=get_current_device()) if "gemini" in args.plugin else nullcontext() + + booster_policy = None + with init_ctx: + actor = OPTForCausalLM(config=get_gpt_config(args.pretrain)) + # Disable dropout + disable_dropout(actor) + ref_model = OPTForCausalLM(config=get_gpt_config(args.pretrain)) + reward_model = RewardModel(config=get_gpt_config("350m")) + critic = Critic(config=get_gpt_config("350m")) + disable_dropout(critic) + + actor_numel = get_model_numel(actor, args.plugin, args.tp) + critic_numel = get_model_numel(critic, args.plugin, args.tp) + initial_model_numel = get_model_numel(ref_model, args.plugin, args.tp) + reward_model_numel = get_model_numel(reward_model, args.plugin, 
args.tp) + + performance_evaluator = PerformanceEvaluator( + actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=2, + train_config={"model": "facebook/opt-" + args.pretrain, "lora_rank": args.lora_rank, "plugin": args.plugin}, + save_path="./benchmark_performance_summarization.txt", + ) + + if args.tp > 1: + if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]: + raise ValueError("Reward model and critic model must have the same architecture") + if reward_model.model.config.architectures[0] == "BloomForCausalLM": + from colossalai.shardformer.policies.bloom import BloomPolicy + + booster_policy = BloomPolicy() + elif reward_model.model.config.architectures[0] == "LlamaForCausalLM": + from colossalai.shardformer.policies.llama import LlamaPolicy + + booster_policy = LlamaPolicy() + elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel": + from colossalai.shardformer.policies.gpt2 import GPT2Policy + + booster_policy = GPT2Policy() + elif reward_model.model.config.architectures[0] == "ChatGLMModel": + from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy + + booster_policy = ChatGLMPolicy() + elif reward_model.model.config.architectures[0] == "OPTForCausalLM": + from colossalai.shardformer.policies.opt import OPTPolicy + + booster_policy = OPTPolicy() + else: + raise ValueError("Unknown model architecture for policy") + + if args.lora_rank > 0: + actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias) + critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias) + + if args.grad_checkpoint and args.lora_rank == 0: + actor.gradient_checkpointing_enable() + critic.model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + elif args.lora_rank > 0: + coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") + + if args.use_flash_attn: + replace_with_flash_attention(model=actor) + replace_with_flash_attention(model=critic) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + if os.path.exists(args.conversation_template_config): + conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8")) + conversation_template = setup_conversation_template( + tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config + ) + stop_token_ids = ( + conversation_template.assistant_line_end if len(conversation_template.assistant_line_end) > 0 else None + ) + else: + raise ValueError("Conversation template config is not provided or incorrect") + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." 
+ ) + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + tokenizer.padding_side = "left" # left padding for generation (online learning) + + # configure generation config + actor.generation_config.update( + pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id + ) + + # configure optimizer + coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}") + actor_optim = HybridAdam( + model_params=actor.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}") + critic_optim = HybridAdam( + model_params=critic.parameters(), + lr=args.critic_lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map) + coordinator.print_on_master(f"prompt dataset size: {len(train_prompt_dataset)}") + data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len) + train_prompt_dataloader = setup_distributed_dataloader( + dataset=train_prompt_dataset, + batch_size=args.experience_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + + if len(args.pretrain_dataset) > 0: + train_pretrain_dataset = load_tokenized_dataset( + dataset_paths=args.pretrain_dataset, mode="train", mode_map=mode_map + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) + train_pretrain_dataloader = setup_distributed_dataloader( + dataset=train_pretrain_dataset, + batch_size=args.ptx_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + else: + train_pretrain_dataloader = None + + if args.warmup_steps is None: + args.warmup_steps = int(0.025 * args.num_episodes) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + actor_lr_scheduler = CosineAnnealingWarmupLR( + optimizer=actor_optim, + total_steps=args.num_episodes, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + critic_lr_scheduler = CosineAnnealingWarmupLR( + optimizer=critic_optim, + total_steps=args.num_episodes, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + precision=args.mixed_precision, + ) + 
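+        # The actor and reference model are boosted with the default `plugin` created above,
+        # while the reward model and critic get a second plugin instance that carries the
+        # architecture-specific shardformer `booster_policy` resolved earlier, so their
+        # tensor-parallel sharding matches their own architecture.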
custom_plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + precision=args.mixed_precision, + custom_policy=booster_policy, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + if args.plugin != "3d": + custom_plugin = plugin + + actor_booster = Booster(plugin=plugin) + ref_booster = Booster(plugin=plugin) + rm_booster = Booster(plugin=custom_plugin) + critic_booster = Booster(plugin=custom_plugin) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost( + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + dataloader=train_prompt_dataloader, + ) + + critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost( + model=critic, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + dataloader=train_prompt_dataloader, + ) + reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader) + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader) + + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + sampler_start_idx = 0 + start_step = 0 + + if args.rm_checkpoint_path is not None: + if "modeling" in args.rm_checkpoint_path: + rm_booster.load_model(reward_model, args.rm_checkpoint_path) + else: + _, _, _ = load_checkpoint( + load_dir=args.rm_checkpoint_path, + booster=rm_booster, + model=reward_model, + optimizer=None, + lr_scheduler=None, + ) + coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}") + + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + actor_booster.load_model(actor, args.checkpoint_path) + ref_booster.load_model(ref_model, args.checkpoint_path) + coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}") + else: + _, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=actor_booster, + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + ) + _, _, _ = load_checkpoint( + load_dir=args.checkpoint_path, + booster=ref_booster, + model=ref_model, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + ) + assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler) + train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + if args.critic_checkpoint_path is not None: + if "modeling" in args.critic_checkpoint_path: + critic_booster.load_model(critic, args.critic_checkpoint_path) + else: + _, _, _ = load_checkpoint( + 
load_dir=args.critic_checkpoint_path, + booster=critic_booster, + model=critic, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + ) + coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}") + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + # configure trainer + trainer = PPOTrainer( + actor_booster, + critic_booster, + actor, + critic, + reward_model, + ref_model, + actor_optim, + critic_optim, + actor_lr_scheduler, + critic_lr_scheduler, + tokenizer=tokenizer, + stop_token_ids=stop_token_ids, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + train_batch_size=args.train_batch_size, + buffer_limit=args.num_collect_steps * args.experience_batch_size, + max_length=args.max_length, + max_new_tokens=args.max_seq_len, + use_cache=True, + do_sample=True, + temperature=0.7, + accumulation_steps=args.accumulation_steps, + save_dir=args.save_path, + save_interval=args.save_interval, + top_k=50, + use_tp=args.tp > 1, + offload_inference_models="gemini" not in args.plugin, + callbacks=[performance_evaluator], + coordinator=coordinator, + ) + + trainer.fit( + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + prompt_dataloader=train_prompt_dataloader, + pretrain_dataloader=train_pretrain_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + actor.eval() + critic.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final actor model checkpoint") + actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}" + ) + coordinator.print_on_master("Start saving final critic model checkpoint") + critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}" + ) + memory_consumption = torch.cuda.max_memory_allocated() / 1024**2 + if is_rank_0(): + with open("./benchmark_memory_consumption.txt", "a+") as f: + f.write( + f"Model=Opt-{args.pretrain}; lora_rank={args.lora_rank}; plugin={args.plugin}\nMax CUDA memory usage: {memory_consumption:.2f} MB\n" + ) + coordinator.print_on_master(f"Max CUDA memory usage: {memory_consumption:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt_dataset", nargs="+", default=[]) + parser.add_argument("--pretrain_dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument( + "--conversation_template_config", + type=str, + default=None, + help="Path \ + to save conversation template config files.", + ) + 
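+    # NOTE: benchmark_train() requires this file to exist (benchmark_ppo.sh passes ./Opt.json),
+    # and setup_conversation_template() writes the resolved template back to the same path.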
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--critic_checkpoint_path", type=str, default=None) + parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--num_collect_steps", type=int, default=2) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--save_interval", type=int, default=1000) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--experience_batch_size", type=int, default=16) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--lora_train_bias", type=str, default="none") + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=9e-6) + parser.add_argument("--critic_lr", type=float, default=9e-6) + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.0) + parser.add_argument("--max_length", type=int, default=512) + parser.add_argument("--max_seq_len", type=int, default=256) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + benchmark_train(args) diff --git a/applications/ColossalChat/benchmarks/benchmark_ppo.sh b/applications/ColossalChat/benchmarks/benchmark_ppo.sh new file mode 100755 index 000000000..e88757659 --- /dev/null +++ b/applications/ColossalChat/benchmarks/benchmark_ppo.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +set -xu + +NUM_RETRY=3 +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples +TEMP_DIR=$BASE_DIR/temp +MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models +MODELS_DIR=$TEMP_DIR/models_config +# To benchmark different models, change the following line +# MODELS=('125m' '350m' '700m' '1.3b' '2.7b' '3.5b' '5.5b' '6.7b' '10b' '13b') +MODELS=('125m') +# To benchmark different strategies, change the following line +# 
PLUGINS=('zero2', 'zero2_cpu', '3d') +PLUGINS=('zero2') +LORA_RANK=('0') + +export OMP_NUM_THREADS=8 + +rm ./benchmark_memory_consumption.txt +rm ./benchmark_performance_summarization.txt + +# install requirements +pip install -r $EXAMPLES_DIR/requirements.txt + +random_choice() { + local arr=("$@") + local len=${#arr[@]} + local idx=$((RANDOM % len)) + echo ${arr[$idx]} +} + +echo "[Test]: testing ppo ..." + +SKIPPED_TESTS=( +) + +GRAD_CKPTS=('' '--grad_checkpoint') +GRAD_CKPTS=('') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + plugins=($(shuf -e "${PLUGINS[@]}")) + for plugin in ${plugins[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$model + tokenizer_dir="facebook/opt-125m" + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + if [[ $plugin == "3d" ]]; then + tp='4' + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a prompt_dataset=() + for split in $(seq -f "%05g" 0 9); do + prompt_dataset+=("$TEMP_DIR/benchmark/arrow/part-$split") + done + colossalai run --nproc_per_node 8 --master_port 28547 $BASE_DIR/benchmarks/benchmark_ppo.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --prompt_dataset ${prompt_dataset[@]} \ + --ptx_coef 0 \ + --save_path $MODEL_SAVE_PATH \ + --conversation_template_config ./Opt.json \ + --lora_rank $lora_rank \ + --plugin $plugin \ + --num_episodes 5 \ + --num_collect_steps 1 \ + --num_update_steps 1 \ + --max_seq_len 128 \ + --max_length 512 \ + --experience_batch_size 32 \ + --train_batch_size 32 \ + --accumulation_steps 1 \ + --lr 9e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --use_flash_attn \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt + passed=$? 
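+                # $passed holds the exit code of the colossalai run above: on success the
+                # temporary checkpoints are cleaned up and the retry loop breaks, otherwise
+                # the command is retried up to NUM_RETRY times before the test is failed.
+                # Note that --lr is passed twice above; argparse keeps the last value (2e-5).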
+ if [ $passed -eq 0 ]; then + rm -rf $MODEL_SAVE_PATH/* + rm -rf $MODELS_DIR/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done diff --git a/applications/ColossalChat/benchmarks/data_preparation.sh b/applications/ColossalChat/benchmarks/data_preparation.sh new file mode 100755 index 000000000..ca2986be4 --- /dev/null +++ b/applications/ColossalChat/benchmarks/data_preparation.sh @@ -0,0 +1,16 @@ +SAVE_DIR="" + + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples +SAVE_DIR=$BASE_DIR/temp/benchmark + +rm -rf $SAVE_DIR + +python $EXAMPLES_DIR/data_preparation_scripts/prepare_prompt_dataset.py --data_input_dirs "/home/yeanbang/data/dataset/sft_data/alpaca/data_preprocessed/train" \ + --conversation_template_config ./Opt.json \ + --tokenizer_dir "facebook/opt-125m" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ + --num_samples_per_datafile 30 diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/ColossalChat/benchmarks/ray/1mmt_dummy.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/benchmarks/ray/1mmt_dummy.py rename to applications/ColossalChat/benchmarks/ray/1mmt_dummy.py diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/ColossalChat/benchmarks/ray/mmmt_dummy.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/benchmarks/ray/mmmt_dummy.py rename to applications/ColossalChat/benchmarks/ray/mmmt_dummy.py diff --git a/applications/Chat/coati/__init__.py b/applications/ColossalChat/coati/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/__init__.py rename to applications/ColossalChat/coati/__init__.py diff --git a/applications/ColossalChat/coati/dataset/__init__.py b/applications/ColossalChat/coati/dataset/__init__.py new file mode 100755 index 000000000..e216c37e1 --- /dev/null +++ b/applications/ColossalChat/coati/dataset/__init__.py @@ -0,0 +1,26 @@ +from .conversation import Conversation, setup_conversation_template +from .loader import ( + DataCollatorForPreferenceDataset, + DataCollatorForPromptDataset, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from .tokenization_utils import supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf + +__all__ = [ + "tokenize_prompt_dataset", + "DataCollatorForPromptDataset", + "is_rank_0", + "DataCollatorForPreferenceDataset", + "DataCollatorForSupervisedDataset", + "StatefulDistributedSampler", + "load_tokenized_dataset", + "setup_distributed_dataloader", + "supervised_tokenize_pretrain", + "supervised_tokenize_sft", + "tokenize_rlhf", + "setup_conversation_template", + "Conversation", +] diff --git a/applications/ColossalChat/coati/dataset/conversation.py b/applications/ColossalChat/coati/dataset/conversation.py new file mode 100755 index 000000000..15a33be93 --- /dev/null +++ b/applications/ColossalChat/coati/dataset/conversation.py @@ -0,0 +1,143 @@ +import dataclasses +import json +import os +from typing import Any, Dict, List + +import torch.distributed as dist +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +@dataclasses.dataclass +class Conversation: + tokenizer: PreTrainedTokenizer + 
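+    # NOTE: `messages` is not a dataclass field; it is initialised by clear(),
+    # which from_config() calls right after constructing the instance.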
system_message: str + chat_template: str + stop_ids: List[int] + + @classmethod + def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict): + """ + Setup the conversation template from config + """ + tokenizer.chat_template = config["chat_template"] + conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"]) + conv.clear() + return conv + + def clear(self): + self.messages = [] + + @classmethod + def get_conversation_template_keys(cls): + return ["system_message", "chat_template"] + + def __str__(self): + return json.dumps( + {k: self.__dict__[k] for k in self.__dict__ if k not in ["tokenizer", "messages"]}, + ensure_ascii=False, + indent=4, + ) + + def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any: + """ + Retrieves the prompt for the conversation. + + Args: + length (int, optional): The number of messages to include in the prompt. Defaults to None. + get_seps_info (bool, optional): Whether to include separator information in the output. Defaults to False. + add_generation_prompt (bool, optional): Whether to add the assistant line start token in generation (for generation only). Defaults to False. + + Returns: + str or tuple: The prompt string if get_seps_info is False, otherwise a tuple containing the prompt string and separator information. + """ + + if length is None: + length = len(self.messages) + + assert length <= len(self.messages) + if self.system_message is not None: + messages = [{"role": "system", "content": self.system_message}] + self.messages[:length] + else: + messages = self.messages[:length] + prompt = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=add_generation_prompt + ) + return prompt + + def save_prompt(self): + return self.get_prompt() + + def append_message(self, role: str, message: str): + """ + Append a message to the conversation. + + Args: + role (str): The role of the message sender. Must be either 'user' or 'assistant'. + message (str): The content of the message. + + Raises: + AssertionError: If the role is not 'user' or 'assistant'. + """ + assert role in ["user", "assistant"] + self.messages.append({"role": role, "content": message}) + + def copy(self): + return Conversation(tokenizer=self.tokenizer, chat_template=self.chat_template) + + +def setup_conversation_template( + tokenizer: PreTrainedTokenizer, chat_template_config: Dict = None, save_path: str = None +) -> Conversation: + """ + Setup the conversation template, if chat_template is given, will replace the default chat_template of the tokenizer + with it. Otherwise, the default chat_template will be used. If the tokenizer doesn't have a default chat_template, + raise error to remind the user to set it manually. + + Args: + tokenizer: The tokenizer to use + chat_template_config: + { + "system_message": str The system message to use + "chat_template": str The chat_template to use, if can be a chat_template, a huggingface model path or a local model. + if a huggeface model path or a local model, the chat_template will be loaded from the model's tokenizer's default chat template. + "stop_ids": List[int], the token ids used to terminate generation. You need to provide this for ppo training and generation. 
+ } + """ + if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]): + # Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually + if "system_message" not in chat_template_config: + logger.warning("No system message is provided, will not use system message.") + if "chat_template" not in chat_template_config: + logger.warning("No chat_template is provided, will try to load it from the tokenizer.") + if tokenizer.chat_template != None: + chat_template_config["chat_template"] = tokenizer.chat_template + else: + raise ValueError( + f"Load a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set it manually." + ) + else: + try: + tokenizer = AutoTokenizer.from_pretrained(chat_template_config["chat_template"]) + if tokenizer.chat_template != None: + chat_template_config["chat_template"] = tokenizer.chat_template + else: + raise ValueError( + f"Load a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set it manually." + ) + logger.warning( + f"chat_template is provided as a local model path or huggingface model path, loaded chat_template from \"{chat_template_config['chat_template']}\"." + ) + except OSError: + pass + except ValueError as e: + raise ValueError(e) + if not dist.is_initialized() or dist.get_rank() == 0: + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, "w", encoding="utf8") as f: + logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.") + json.dump(chat_template_config, f, indent=4, ensure_ascii=False) + return Conversation.from_config(tokenizer, chat_template_config) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py new file mode 100755 index 000000000..93cc1dab8 --- /dev/null +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Dataloader for sft, dpo, ppo +""" + +import math +import os +import random +from dataclasses import dataclass +from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from coati.dataset.utils import chuncate_sequence, pad_to_max_len +from datasets import Dataset as HFDataset +from datasets import dataset_dict, load_from_disk +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler +from transformers.tokenization_utils import PreTrainedTokenizer + +DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] +PathType = Union[str, os.PathLike] + + +def load_tokenized_dataset( + dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs +) -> Optional[DatasetType]: + """ + Load pre-tokenized dataset. + Each instance of dataset is a dictionary with + `{'input_ids': List[int], 'labels': List[int], sequence: str}` format. 
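+    The optional `mode_map` kwarg maps the logical split name ("train"/"dev"/"test") to the
+    split key stored on disk (by default "dev" maps to "validation"). Returns None when no
+    matching split is found, the dataset itself for a single match, or a ConcatDataset otherwise.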
+ """ + mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"}) + assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}" + + if isinstance(dataset_paths, (str, os.PathLike)): + dataset_paths = [dataset_paths] + + datasets = [] # `List[datasets.dataset_dict.Dataset]` + for ds_path in dataset_paths: + ds_path = os.path.abspath(ds_path) + assert os.path.exists(ds_path), f"Not existed file path {ds_path}" + ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False) + if isinstance(ds_dict, HFDataset): + datasets.append(ds_dict) + else: + if mode_map[mode] in ds_dict: + datasets.append(ds_dict[mode_map[mode]]) + if len(datasets) == 0: + return None + if len(datasets) == 1: + return datasets.pop() + return ConcatDataset(datasets=datasets) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """ + Collate instances for supervised dataset. + Each instance is a tokenized dictionary with fields + `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + ignore_index: int = -100 + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. + """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + + # `List[torch.Tensor]` + batch_input_ids = [ + torch.LongTensor(instance["input_ids"][: self.max_length]) + if len(instance["input_ids"]) > self.max_length + else torch.LongTensor(instance["input_ids"]) + for instance in instances + ] + batch_labels = [ + torch.LongTensor(instance["labels"][: self.max_length]) + if len(instance["labels"]) > self.max_length + else torch.LongTensor(instance["labels"]) + for instance in instances + ] + if self.tokenizer.padding_side == "right": + input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=batch_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + labels = torch.nn.utils.rnn.pad_sequence( + sequences=batch_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + elif self.tokenizer.padding_side == "left": + reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] + reversed_input_ids = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) # (bsz, max_len) + input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len) + reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels] + reversed_labels = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_labels, + batch_first=True, + padding_value=self.ignore_index, + ) # (bsz, max_len) + labels = torch.flip(reversed_labels, 
dims=(1,)) # (bsz, max_len) + else: + raise RuntimeError( + f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, " + f"but now `{self.tokenizer.padding_side}`" + ) + + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len) + + return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +@dataclass +class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset): + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + """ + instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances] + ret = super().__call__(instances=instances) + input_ids = F.pad( + ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id + ) + attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@dataclass +class DataCollatorForPreferenceDataset(object): + """ + Collate instances for supervised dataset. + Each instance is a tokenized dictionary with fields + `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str). + """ + + tokenizer: PreTrainedTokenizer + max_length: int = 4096 + + def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: + """ + + Args: + instances (`Sequence[Dict[str, List[int]]]`): + Mini-batch samples, each sample is stored in an individual dictionary. + + Returns: + (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`: + `input_ids`: `torch.Tensor` of shape (bsz, max_len); + `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len); + `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`. 
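+            For preference data the actual returned keys are the chosen_/reject_ variants:
+            `chosen_input_ids`, `chosen_attention_mask`, `chosen_loss_mask`,
+            `reject_input_ids`, `reject_attention_mask`, `reject_loss_mask`
+            (see the return statement below).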
+ """ + assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, ( + f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, " + f"but now `{self.tokenizer.pad_token_id}`" + ) + + ( + chosen_input_ids, + chosen_loss_mask, # [batch_size * seq_len] + reject_input_ids, + reject_loss_mask, + ) = ( + chuncate_sequence([ins["chosen_input_ids"] for ins in instances], self.max_length, torch.int64), + chuncate_sequence([ins["chosen_loss_mask"] for ins in instances], self.max_length, torch.bool), + chuncate_sequence([ins["rejected_input_ids"] for ins in instances], self.max_length, torch.int64), + chuncate_sequence([ins["rejected_loss_mask"] for ins in instances], self.max_length, torch.bool), + ) + + padding_side = self.tokenizer.padding_side + chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids] + reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids] + + ( + chosen_input_ids, + chosen_attention_mask, + chosen_loss_mask, + reject_input_ids, + reject_attention_mask, + reject_loss_mask, + ) = ( + pad_to_max_len(chosen_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side), + pad_to_max_len(chosen_attention_mask, self.max_length, False, padding_side=padding_side), + pad_to_max_len(chosen_loss_mask, self.max_length, False, padding_side=padding_side), + pad_to_max_len(reject_input_ids, self.max_length, self.tokenizer.pad_token_id, padding_side=padding_side), + pad_to_max_len(reject_attention_mask, self.max_length, False, padding_side=padding_side), + pad_to_max_len(reject_loss_mask, self.max_length, False, padding_side=padding_side), + ) + + return dict( + chosen_input_ids=chosen_input_ids, + chosen_attention_mask=chosen_attention_mask, + chosen_loss_mask=chosen_loss_mask, + reject_input_ids=reject_input_ids, + reject_attention_mask=reject_attention_mask, + reject_loss_mask=reject_loss_mask, + ) + + +class StatefulDistributedSampler(DistributedSampler): + """ + Stateful distributed sampler for multi-stage training. + """ + + def __init__( + self, + dataset: DatasetType, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + use_tp: Optional[bool] = False, + ) -> None: + if not use_tp: + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + else: + # adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62 + # TODO: support tp_group>1. will fix it later + num_replicas = 1 + if rank is None: + rank = dist.get_rank() + if rank < 0: + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]") + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. 
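+                # (With num_replicas fixed to 1 above this branch cannot trigger yet; the
+                # general formula is kept for future tp_group > 1 support. For example, with
+                # 10 samples and 4 replicas it gives num_samples = ceil((10 - 4) / 4) = 2 and
+                # total_size = 8, so the 2 tail samples are dropped.)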
+ self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + self.start_index = 0 + self.use_tp = use_tp + + def __iter__(self) -> Iterator: + if self.use_tp: + # TODO Add support for tp_group not equal to 1 + pass + # adpated from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96 + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[ + : self.total_size : self.num_replicas + ] # num_replicas=tp_group=1, we only support tp_group==1 for now + assert len(indices) == self.num_samples + + return iter(indices) + + else: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def setup_distributed_dataloader( + dataset: DatasetType, + batch_size: int = 1, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, + process_group: Optional[ProcessGroup] = None, + use_tp: Optional[bool] = False, + **kwargs, +) -> DataLoader: + """ + Setup dataloader for distributed training. 
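+
+    Example (illustrative; `dataset` and `data_collator` are placeholders):
+        dataloader = setup_distributed_dataloader(
+            dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=data_collator
+        )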
+ """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset=dataset, + num_replicas=process_group.size() if not use_tp else 1, + rank=process_group.rank(), + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + use_tp=use_tp, + ) + + # Deterministic dataloader + def seed_worker(worker_id: int) -> None: + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + worker_init_fn=seed_worker, + **_kwargs, + ) diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py new file mode 100755 index 000000000..7606bc2a9 --- /dev/null +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +tokenization utils for constructing dataset for ppo, dpo, sft, rm +""" + +import warnings +from copy import deepcopy +from typing import Any, Dict, List, Union + +from coati.dataset.conversation import Conversation +from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate +from datasets import dataset_dict +from torch.utils.data import ConcatDataset, Dataset +from transformers import PreTrainedTokenizer + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + +IGNORE_INDEX = -100 + +DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] + + +def supervised_tokenize_sft( + data_point: Dict[str, str], + tokenizer: PreTrainedTokenizer, + conversation_template: Conversation = None, + ignore_index: int = None, + max_length: int = 4096, +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following + and calculate corresponding labels for sft training: + "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line end]Something here" + ^ + end_of_system_line_position + + Args: + data_point: the data point of the following format + {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} + tokenizer: the tokenizer whose + conversation_template: the conversation template to apply + ignore_index: the ignore index when calculate loss during training + max_length: the maximum context length + """ + + if ignore_index is None: + ignore_index = IGNORE_INDEX + + messages = data_point["messages"] + template = deepcopy(conversation_template) + template.messages = [] + + for mess in messages: + from_str = mess["from"] + if from_str.lower() == "human": + from_str = "user" + elif from_str.lower() == "assistant": + from_str = "assistant" + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + + template.append_message(from_str, mess["content"]) + + if len(template.messages) % 2 != 0: + template.messages = template.messages[0:-1] + + # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. 
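+    # The binary search below finds the first turn count whose templated prompt no longer fits
+    # in max_length - 1 tokens; training then keeps turns[target_turn_index - 1], the largest
+    # number of turns that still fits.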
+ turns = [i for i in range(1, len(messages) // 2 + 1)] + + lo, hi = 0, len(turns) + while lo < hi: + mid = (lo + hi) // 2 + if max_length - 1 < len( + tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0] + ): + hi = mid + else: + lo = mid + 1 + target_turn_index = lo + + # The tokenized length for first turn already exceeds `max_length - 1`. + if target_turn_index - 1 < 0: + warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.") + return dict( + input_ids=None, + labels=None, + inputs_decode=None, + labels_decode=None, + seq_length=None, + seq_category=None, + ) + + target_turn = turns[target_turn_index - 1] + prompt = template.get_prompt(2 * target_turn) + chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt) + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) + + labels = [ignore_index] * len(tokenized) + label_decode = [] + for start, end in zip(starts, ends): + if end == len(tokenized): + tokenized = tokenized + [tokenizer.eos_token_id] + labels = labels + [ignore_index] + labels[start : end + 1] = tokenized[start : end + 1] + label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False)) + + if tokenizer.bos_token_id is not None: + if tokenized[0] != tokenizer.bos_token_id: + tokenized = [tokenizer.bos_token_id] + tokenized + labels = [ignore_index] + labels + + if tokenizer.eos_token_id is not None: + # Force to add eos token at the end of the tokenized sequence + if tokenized[-1] != tokenizer.eos_token_id: + tokenized = tokenized + [tokenizer.eos_token_id] + labels = labels + [tokenizer.eos_token_id] + else: + labels[-1] = tokenizer.eos_token_id + + # For some model without bos/eos may raise the following errors + try: + inputs_decode = tokenizer.decode(tokenized) + except TypeError as e: + raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}") + + # Check if all labels are ignored, this may happen when the tokenized length is too long + if labels.count(ignore_index) == len(labels): + return dict( + input_ids=None, + labels=None, + inputs_decode=None, + labels_decode=None, + seq_length=None, + seq_category=None, + ) + + return dict( + input_ids=tokenized, + labels=labels, + inputs_decode=inputs_decode, + labels_decode=label_decode, + seq_length=len(tokenized), + seq_category=data_point["category"] if "category" in data_point else "None", + ) + + +def tokenize_prompt_dataset( + data_point: Dict[str, str], + tokenizer: PreTrainedTokenizer, + conversation_template: Conversation = None, + ignore_index: int = None, + max_length: int = 4096, +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following for ppo training: + "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]" + Args: + data_point: the data point of the following format + {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]} + tokenizer: the tokenizer whose + conversation_template: the conversation template to apply + ignore_index: the ignore index when calculate loss during training + max_length: the maximum context length + """ + if ignore_index is None: + ignore_index = IGNORE_INDEX + + messages = data_point["messages"] + template = deepcopy(conversation_template) + template.messages = [] + + for mess 
in messages: + from_str = mess["from"] + if from_str.lower() == "human": + from_str = "user" + elif from_str.lower() == "assistant": + from_str = "assistant" + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + + template.append_message(from_str, mess["content"]) + + # `target_turn_index` is the number of turns which exceeds `max_length - 1` for the first time. + target_turn = len(template.messages) + if target_turn % 2 != 1: + # exclude the answer if provided. keep only the prompt + target_turn = target_turn - 1 + + # Prepare data + prompt = template.get_prompt(target_turn, add_generation_prompt=True) + tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0] + if tokenizer.bos_token_id is not None: + if tokenized[0] != tokenizer.bos_token_id: + tokenized = [tokenizer.bos_token_id] + tokenized + + # Skip overlength data + if max_length - 1 < len(tokenized): + return dict( + input_ids=None, + inputs_decode=None, + seq_length=None, + seq_category=None, + ) + + # `inputs_decode` can be used to check whether the tokenization method is true. + return dict( + input_ids=tokenized, + inputs_decode=tokenizer.decode(tokenized), + seq_length=len(tokenized), + seq_category=data_point["category"] if "category" in data_point else "None", + ) + + +def apply_rlhf_data_format( + template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False +): + target_turn = int(len(template.messages) / 2) + prompt = template.get_prompt(target_turn * 2) + chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt) + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss) + loss_mask = [0] * len(tokenized) + mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id + if mask_token is None: + mask_token = 1 # If the tokenizer doesn't have eos_token or pad_token: Qwen + + label_decode = [] + for start, end in zip(starts[-1:], ends[-1:]): + # only the last round (chosen/rejected) counts + if end == len(tokenized): + tokenized = tokenized + [tokenizer.eos_token_id] + loss_mask = loss_mask + [1] + loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1]) + label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False)) + if tokenizer.bos_token_id is not None: + if tokenized[0] != tokenizer.bos_token_id: + tokenized = [tokenizer.bos_token_id] + tokenized + loss_mask = [0] + loss_mask + + if tokenizer.eos_token_id is not None: + # Force to add eos token at the end of the tokenized sequence + if tokenized[-1] != tokenizer.eos_token_id: + tokenized = tokenized + [tokenizer.eos_token_id] + loss_mask = loss_mask + [1] + else: + loss_mask[-1] = 1 + + return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode} + + +def tokenize_rlhf( + data_point: Dict[str, str], + tokenizer: PreTrainedTokenizer, + conversation_template: Conversation = None, + ignore_index: int = None, + max_length: int = 4096, +) -> Dict[str, Union[int, str, List[int]]]: + """ + A tokenization function to tokenize an original pretraining data point as following: + {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}], + "chosen": {"from": "assistant", "content": "xxx"}, "rejected": {"from": "assistant", "content": "xxx"}} + """ + if ignore_index is None: + ignore_index = IGNORE_INDEX + + context = data_point["context"] + template = deepcopy(conversation_template) + template.clear() + + for mess in context: + from_str = 
mess["from"] + if from_str.lower() == "human": + from_str = "user" + elif from_str.lower() == "assistant": + from_str = "assistant" + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + + if len(template.messages) > 0 and from_str == template.messages[-1]["role"]: + # Concate adjacent message from the same role + template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"]) + else: + template.append_message(from_str, mess["content"]) + + if len(template.messages) % 2 != 1: + warnings.warn( + "Please make sure leading context starts and ends with a line from human\nLeading context: " + + str(template.messages) + ) + return dict( + chosen_input_ids=None, + chosen_loss_mask=None, + chosen_label_decode=None, + rejected_input_ids=None, + rejected_loss_mask=None, + rejected_label_decode=None, + ) + round_of_context = int((len(template.messages) - 1) / 2) + + assert context[-1]["from"].lower() == "human", "The last message in context should be from human." + chosen = deepcopy(template) + rejected = deepcopy(template) + + for round in range(len(data_point["chosen"])): + from_str = data_point["chosen"][round]["from"] + if from_str.lower() == "human": + from_str = "user" + elif from_str.lower() == "assistant": + from_str = "assistant" + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + chosen.append_message(from_str, data_point["chosen"][round]["content"]) + + for round in range(len(data_point["rejected"])): + from_str = data_point["rejected"][round]["from"] + if from_str.lower() == "human": + from_str = "user" + elif from_str.lower() == "assistant": + from_str = "assistant" + else: + raise ValueError(f"Unsupported role {from_str.lower()}") + rejected.append_message(from_str, data_point["rejected"][round]["content"]) + + ( + chosen_input_ids, + chosen_loss_mask, + chosen_label_decode, + rejected_input_ids, + rejected_loss_mask, + rejected_label_decode, + ) = (None, None, None, None, None, None) + if ( + len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0]) + <= max_length - 1 + and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0]) + <= max_length - 1 + ): + chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context) + (chosen_input_ids, chosen_loss_mask, chosen_label_decode) = ( + chosen_data_packed["input_ids"], + chosen_data_packed["loss_mask"], + chosen_data_packed["label_decode"], + ) + + rejected_data_packed = apply_rlhf_data_format( + rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True + ) + (rejected_input_ids, rejected_loss_mask, rejected_label_decode) = ( + rejected_data_packed["input_ids"], + rejected_data_packed["loss_mask"], + rejected_data_packed["label_decode"], + ) + + # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long + if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask): + return dict( + chosen_input_ids=None, + chosen_loss_mask=None, + chosen_label_decode=None, + rejected_input_ids=None, + rejected_loss_mask=None, + rejected_label_decode=None, + ) + + return { + "chosen_input_ids": chosen_input_ids, + "chosen_loss_mask": chosen_loss_mask, + "chosen_label_decode": chosen_label_decode, + "rejected_input_ids": rejected_input_ids, + "rejected_loss_mask": rejected_loss_mask, + "rejected_label_decode": rejected_label_decode, + } + else: + return dict( + 
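+            # Either the chosen or the rejected sequence is longer than max_length - 1,
+            # so this sample is skipped by returning all-None fields.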
chosen_input_ids=None, + chosen_loss_mask=None, + chosen_label_decode=None, + rejected_input_ids=None, + rejected_loss_mask=None, + rejected_label_decode=None, + ) diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py new file mode 100755 index 000000000..ada2afef0 --- /dev/null +++ b/applications/ColossalChat/coati/dataset/utils.py @@ -0,0 +1,138 @@ +import io +import json +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers import PreTrainedTokenizer + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +def _make_r_io_base(f, mode: str): + if not isinstance(f, io.IOBase): + f = open(f, mode=mode) + return f + + +def jload(f, mode="r"): + """Load a .json file into a dictionary.""" + f = _make_r_io_base(f, mode) + jdict = json.load(f) + f.close() + return jdict + + +def read_string_by_schema(data: Dict[str, Any], schema: str) -> str: + """ + Read a feild of the dataset be schema + Args: + data: Dict[str, Any] + schema: cascaded feild names seperated by '.'. e.g. person.name.first will access data['person']['name']['first'] + """ + keys = schema.split(".") + result = data + for key in keys: + result = result.get(key, None) + if result is None: + return "" + assert isinstance(result, str), f"dataset element is not a string: {result}" + return result + + +def pad_to_max_len( + sequence: List[torch.Tensor], max_length: int, padding_value: int, batch_first: bool = True, padding_side="left" +): + """ + Args: + sequence: a batch of tensor of shape [batch_size, seq_len] if batch_first==True + """ + if padding_side == "left": + reversed_sequence = [seq.flip(dims=(0,)) for seq in sequence] + padded = torch.nn.utils.rnn.pad_sequence( + sequences=reversed_sequence, batch_first=batch_first, padding_value=padding_value + ) + to_pad = max_length - padded.size(1) + padded = F.pad(padded, (0, to_pad), value=padding_value) + return torch.flip(padded, dims=(1,)) + elif padding_side == "right": + padded = torch.nn.utils.rnn.pad_sequence( + sequences=sequence, batch_first=batch_first, padding_value=padding_value + ) + to_pad = max_length - padded.size(1) + return F.pad(padded, (0, to_pad), value=padding_value) + else: + raise RuntimeError(f"`padding_side` can only be `left` or `right`, " f"but now `{padding_side}`") + + +def chuncate_sequence(sequence: List[torch.Tensor], max_length: int, dtype: Any): + """ + Args: + sequence: a batch of tensor of shape [batch_size, seq_len] if batch_first==True + """ + return [ + torch.Tensor(seq[:max_length]).to(dtype) if len(seq) > max_length else torch.Tensor(seq).to(dtype) + for seq in sequence + ] + + +def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, start_index: int = 0) -> int: + if subseq is None: + return 0 + for i in range(start_index, len(seq) - len(subseq) + 1): + if torch.all(seq[i : i + len(subseq)] == subseq): + return i + return -1 + + +def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]): + """ + Tokenizes a list of texts using the provided tokenizer and concatenates the tokenized outputs. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization. + text (List[str]): The list of texts to tokenize. + require_loss (List[bool]): A list of boolean values indicating whether each text requires loss calculation. 
+ + Returns: + Tuple[List[int], List[int], List[int]]: A tuple containing the concatenated tokenized input ids, + the start positions of loss spans, and the end positions of loss spans. + """ + input_ids = [] + loss_starts = [] + loss_ends = [] + for s, r in zip(text, require_loss): + tokenized = tokenizer(s, add_special_tokens=False)["input_ids"] + if r: + loss_starts.append(len(input_ids)) + loss_ends.append(len(input_ids) + len(tokenized)) + input_ids.extend(tokenized) + return input_ids, loss_starts, loss_ends + + +def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str): + # Seperate templated prompt into chunks by human/assistant's lines, prepare data for tokenize_and_concatenate + start_idx = 0 + chunks = [] + require_loss = [] + for line in messages: + first_occur = prompt.find(line["content"], start_idx) + if prompt[first_occur - 1] != " ": + chunks.append(prompt[start_idx:first_occur]) + chunks.append(prompt[first_occur : first_occur + len(line["content"])]) + else: + chunks.append(prompt[start_idx : first_occur - 1]) + chunks.append(prompt[first_occur - 1 : first_occur + len(line["content"])]) + start_idx = first_occur + len(line["content"]) + if line["role"].lower() == "assistant": + require_loss.append(False) + require_loss.append(True) + else: + require_loss.append(False) + require_loss.append(False) + chunks.append(prompt[start_idx:]) + require_loss.append(False) + return chunks, require_loss diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/ColossalChat/coati/experience_buffer/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/experience_buffer/__init__.py rename to applications/ColossalChat/coati/experience_buffer/__init__.py diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/ColossalChat/coati/experience_buffer/base.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/experience_buffer/base.py rename to applications/ColossalChat/coati/experience_buffer/base.py diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/ColossalChat/coati/experience_buffer/naive.py old mode 100644 new mode 100755 similarity index 86% rename from applications/Chat/coati/experience_buffer/naive.py rename to applications/ColossalChat/coati/experience_buffer/naive.py index d47b67dbe..b912df268 --- a/applications/Chat/coati/experience_buffer/naive.py +++ b/applications/ColossalChat/coati/experience_buffer/naive.py @@ -1,13 +1,16 @@ import random -import warnings from typing import List import torch from coati.experience_maker.base import Experience +from colossalai.logging import get_dist_logger + from .base import ExperienceBuffer from .utils import BufferItem, make_experience_batch, split_experience_batch +logger = get_dist_logger() + class NaiveExperienceBuffer(ExperienceBuffer): """Naive experience buffer class. It stores experience. @@ -35,7 +38,7 @@ class NaiveExperienceBuffer(ExperienceBuffer): if self.limit > 0: samples_to_remove = len(self.items) - self.limit if samples_to_remove > 0: - warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.") + logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.") self.items = self.items[samples_to_remove:] def clear(self) -> None: @@ -43,6 +46,12 @@ class NaiveExperienceBuffer(ExperienceBuffer): @torch.no_grad() def sample(self) -> Experience: + """ + Randomly samples experiences from the buffer. 
+ + Returns: + A batch of sampled experiences. + """ items = random.sample(self.items, self.sample_batch_size) experience = make_experience_batch(items) if self.cpu_offload: diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/ColossalChat/coati/experience_buffer/utils.py old mode 100644 new mode 100755 similarity index 94% rename from applications/Chat/coati/experience_buffer/utils.py rename to applications/ColossalChat/coati/experience_buffer/utils.py index baedbebd1..c4807d179 --- a/applications/Chat/coati/experience_buffer/utils.py +++ b/applications/ColossalChat/coati/experience_buffer/utils.py @@ -26,6 +26,7 @@ class BufferItem: action_log_probs: torch.Tensor values: torch.Tensor reward: torch.Tensor + kl: torch.Tensor advantages: torch.Tensor attention_mask: Optional[torch.LongTensor] action_mask: Optional[torch.BoolTensor] @@ -34,7 +35,7 @@ class BufferItem: def split_experience_batch(experience: Experience) -> List[BufferItem]: batch_size = experience.sequences.size(0) batch_kwargs = [{} for _ in range(batch_size)] - keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") + keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask") for key in keys: value = getattr(experience, key) if isinstance(value, torch.Tensor): @@ -63,7 +64,7 @@ def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> to def make_experience_batch(items: List[BufferItem]) -> Experience: kwargs = {} to_pad_keys = set(("action_log_probs", "action_mask")) - keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") + keys = ("sequences", "action_log_probs", "values", "reward", "kl", "advantages", "attention_mask", "action_mask") for key in keys: vals = [getattr(item, key) for item in items] if key in to_pad_keys: diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/ColossalChat/coati/experience_maker/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/experience_maker/__init__.py rename to applications/ColossalChat/coati/experience_maker/__init__.py diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/ColossalChat/coati/experience_maker/base.py old mode 100644 new mode 100755 similarity index 74% rename from applications/Chat/coati/experience_maker/base.py rename to applications/ColossalChat/coati/experience_maker/base.py index 0731f6e0f..55054c3a0 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/ColossalChat/coati/experience_maker/base.py @@ -3,7 +3,8 @@ from dataclasses import dataclass from typing import Optional import torch -from coati.models.base import Actor, Critic, RewardModel +from coati.models import Critic, RewardModel +from transformers import PreTrainedModel @dataclass @@ -28,6 +29,7 @@ class Experience: action_log_probs: torch.Tensor values: torch.Tensor reward: torch.Tensor + kl: torch.Tensor advantages: torch.Tensor attention_mask: Optional[torch.LongTensor] action_mask: Optional[torch.BoolTensor] @@ -39,6 +41,7 @@ class Experience: self.values = self.values.to(device) self.reward = self.reward.to(device) self.advantages = self.advantages.to(device) + self.kl = self.kl.to(device) if self.attention_mask is not None: self.attention_mask = self.attention_mask.to(device) if self.action_mask is not None: @@ -50,6 +53,7 @@ class Experience: self.values = self.values.pin_memory() 
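# Why Experience exposes `pin_memory`: pinned (page-locked) host tensors allow
# asynchronous host-to-device copies. A minimal sketch of that pattern, assuming
# a CUDA device is available:
import torch

buffers = {"values": torch.randn(4, 128), "reward": torch.randn(4, 128)}
if torch.cuda.is_available():
    buffers = {k: v.pin_memory() for k, v in buffers.items()}
    buffers = {k: v.to("cuda", non_blocking=True) for k, v in buffers.items()}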
self.reward = self.reward.pin_memory() self.advantages = self.advantages.pin_memory() + self.kl = self.kl.pin_memory() if self.attention_mask is not None: self.attention_mask = self.attention_mask.pin_memory() if self.action_mask is not None: @@ -58,7 +62,13 @@ class Experience: class ExperienceMaker(ABC): - def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None: + """ + Base class for experience makers. + """ + + def __init__( + self, actor: PreTrainedModel, critic: Critic, reward_model: RewardModel, initial_model: PreTrainedModel + ) -> None: super().__init__() self.actor = actor self.critic = critic @@ -67,4 +77,14 @@ class ExperienceMaker(ABC): @abstractmethod def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience: - pass + """ + Abstract method to generate an experience. + + Args: + input_ids (torch.Tensor): The input tensor. + attention_mask (torch.Tensor): The attention mask tensor. + **generate_kwargs: Additional keyword arguments for generating the experience. + + Returns: + Experience: The generated experience. + """ diff --git a/applications/ColossalChat/coati/experience_maker/naive.py b/applications/ColossalChat/coati/experience_maker/naive.py new file mode 100755 index 000000000..945bb9557 --- /dev/null +++ b/applications/ColossalChat/coati/experience_maker/naive.py @@ -0,0 +1,180 @@ +""" +experience maker. +""" + +import torch +import torch.nn.functional as F +from coati.dataset.utils import find_first_occurrence_subsequence +from coati.models import Critic, RewardModel +from coati.models.generation import generate +from coati.models.utils import calc_action_log_probs, compute_reward +from transformers import PreTrainedModel, PreTrainedTokenizer + +from colossalai.logging import get_dist_logger + +from .base import Experience, ExperienceMaker + +logger = get_dist_logger() + +import torch.distributed as dist + + +def is_rank_0() -> bool: + return not dist.is_initialized() or dist.get_rank() == 0 + + +class NaiveExperienceMaker(ExperienceMaker): + """ + Naive experience maker. + """ + + def __init__( + self, + actor: PreTrainedModel, + critic: Critic, + reward_model: RewardModel, + initial_model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + kl_coef: float = 0.01, + gamma: float = 1.0, + lam: float = 0.95, + ) -> None: + super().__init__(actor, critic, reward_model, initial_model) + self.tokenizer = tokenizer + self.kl_coef = kl_coef + self.gamma = gamma + self.lam = lam + + @torch.no_grad() + def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor: + """ + Calculates the advantage values for each action based on the value and reward tensors. + + Args: + value (torch.Tensor): Tensor containing the predicted values from critic. + reward (torch.Tensor): reward of the shape [B, len]. + num_actions (int): Number of actions. + + Returns: + torch.Tensor: Tensor containing the calculated advantages for each action. 
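# A toy check of the GAE recurrence implemented in the loop below
# (delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}):
import torch

value = torch.tensor([[0.5, 0.4]])   # V_0, V_1
reward = torch.tensor([[1.0, 0.0]])  # r_0, r_1
gamma, lam = 1.0, 0.95
delta_1 = reward[0, 1] - value[0, 1]                        # -0.4, no bootstrap after the last step
delta_0 = reward[0, 0] + gamma * value[0, 1] - value[0, 0]  # 0.9
adv_1 = delta_1                                             # -0.4
adv_0 = delta_0 + gamma * lam * adv_1                       # 0.9 - 0.38 = 0.52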
+ """ + lastgaelam = 0 + advantages_reversed = [] + for t in reversed(range(num_actions)): + nextvalues = value[:, t + 1] if t < num_actions - 1 else 0.0 + delta = reward[:, t] + self.gamma * nextvalues - value[:, t] + lastgaelam = delta + self.gamma * self.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + return advantages + + @torch.no_grad() + def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience: + """ + Generates an experience using the given input_ids and attention_mask. + + Args: + input_ids (torch.Tensor): The input tensor containing the tokenized input sequence. + attention_mask (torch.Tensor): The attention mask tensor indicating which tokens to attend to. + **generate_kwargs: Additional keyword arguments for the generation process. + + Returns: + Experience: The generated experience object. + + """ + self.actor.eval() + self.critic.eval() + self.initial_model.eval() + self.reward_model.eval() + pad_token_id = self.tokenizer.pad_token_id + + stop_token_ids = generate_kwargs.get("stop_token_ids", None) + torch.manual_seed(41) # for tp, gurantee the same input for reward model + + sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs) + + # Pad to max length + sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id) + sequence_length = sequences.size(1) + + # Calculate auxiliary tensors + attention_mask = None + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + + input_len = input_ids.size(1) + if stop_token_ids is None: + # End the sequence with eos token + eos_token_id = self.tokenizer.eos_token_id + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # Left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + else: + # stop_token_ids are given, generation ends with stop_token_ids + action_mask = torch.ones_like(sequences, dtype=torch.bool) + for i in range(sequences.size(0)): + stop_index = find_first_occurrence_subsequence( + sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device) + ) + if stop_index == -1: + # Sequence does not contain stop_token_ids, this should never happen BTW + logger.warning( + "Generated sequence does not contain stop_token_ids. 
Please check your chat template config" + ) + else: + # Keep stop tokens + stop_index = input_len + stop_index + action_mask[i, stop_index + len(stop_token_ids) :] = False + + generation_end_index = (action_mask == True).sum(dim=-1) - 1 + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + action_mask = action_mask[:, -(sequences.size(1) - input_len) :] + num_actions = action_mask.size(1) + + actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"] + action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions) + + base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"] + + base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions) + + # Convert to right padding for the reward model and the critic model + input_ids_rm = torch.zeros_like(sequences, device=sequences.device) + attention_mask_rm = torch.zeros_like(sequences, device=sequences.device) + for i in range(sequences.size(0)): + sequence = sequences[i] + bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0] + eos_index = generation_end_index[i] + sequence_to_pad = sequence[bos_index:eos_index] + sequence_padded = F.pad( + sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id + ) + input_ids_rm[i] = sequence_padded + if sequence_length - sequence_to_pad.size(0) > 0: + attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1 + else: + attention_mask_rm[i, :] = 1 + attention_mask_rm = attention_mask_rm.to(dtype=torch.bool) + + r = self.reward_model( + input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), + attention_mask=attention_mask_rm.to(device=sequences.device), + ) + + value = self.critic( + input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device), + attention_mask=attention_mask_rm.to(device=sequences.device), + ) + reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) + value = value[:, -num_actions:] * action_mask + advantages = self.calculate_advantage(value, reward, num_actions) + + advantages = advantages.detach() + value = value.detach() + r = r.detach() + + return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask) diff --git a/applications/ColossalChat/coati/models/__init__.py b/applications/ColossalChat/coati/models/__init__.py new file mode 100755 index 000000000..14073207f --- /dev/null +++ b/applications/ColossalChat/coati/models/__init__.py @@ -0,0 +1,24 @@ +from .base import BaseModel +from .critic import Critic +from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn +from .lora import convert_to_lora_module +from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss +from .reward_model import RewardModel +from .utils import disable_dropout + +__all__ = [ + "BaseModel", + "Critic", + "RewardModel", + "PolicyLoss", + "ValueLoss", + "LogSigLoss", + "LogExpLoss", + "convert_to_lora_module", + "DpoLoss", + "generate", + "generate_streaming", + "disable_dropout", + "update_model_kwargs_fn", + "prepare_inputs_fn", +] diff --git a/applications/ColossalChat/coati/models/base.py b/applications/ColossalChat/coati/models/base.py new file mode 100755 index 000000000..fcea9414b --- /dev/null +++ b/applications/ColossalChat/coati/models/base.py @@ -0,0 +1,58 @@ +""" +Base class for critic and reward model +""" + +from typing import Optional + +import torch 
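# A toy sketch of the left-to-right padding conversion done above for the reward
# model and critic (the pad id 0 here is an arbitrary choice for illustration):
import torch
import torch.nn.functional as F

pad_id = 0
seq = torch.tensor([0, 0, 11, 12, 13])                 # left-padded sample
bos = int((seq != pad_id).nonzero().reshape([-1])[0])  # first non-pad position -> 2
right_padded = F.pad(seq[bos:], (0, bos), value=pad_id)
# right_padded == tensor([11, 12, 13, 0, 0])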
+import torch.nn as nn +from transformers import AutoModel, PretrainedConfig + + +class BaseModel(nn.Module): + """ + Actor model base class. + + Args: + pretrained (str): path to pretrained model. + config (PretrainedConfig): PretrainedConfig used to initiate the base model. + **kwargs: all other kwargs as in AutoModel.from_pretrained + """ + + def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None: + super().__init__() + if pretrained is not None: + if config is not None: + # initialize with config and load weights from pretrained + self.model = AutoModel.from_pretrained(pretrained, config=config, **kwargs) + else: + # initialize with pretrained + self.model = AutoModel.from_pretrained(pretrained, **kwargs) + elif config is not None: + # initialize with config + self.model = AutoModel.from_config(config, **kwargs) + else: + raise ValueError("Either pretrained or config must be provided.") + + self.config = self.model.config + # create dummy input to get the size of the last hidden state + if "use_flash_attention_2" in kwargs: + self.model = self.model.cuda() + dummy_input = torch.zeros((1, 1), dtype=torch.long).to(self.model.device) + out = self.model(dummy_input) + self.last_hidden_state_size = out.last_hidden_state.shape[-1] + self.model = self.model.cpu() + # print("self.last_hidden_state_size: ",self.last_hidden_state_size) + + def resize_token_embeddings(self, *args, **kwargs): + """ + Resize the token embeddings of the model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The resized token embeddings. + """ + return self.model.resize_token_embeddings(*args, **kwargs) diff --git a/applications/ColossalChat/coati/models/critic.py b/applications/ColossalChat/coati/models/critic.py new file mode 100755 index 000000000..80340d9bd --- /dev/null +++ b/applications/ColossalChat/coati/models/critic.py @@ -0,0 +1,34 @@ +""" +Critic model +""" + +from typing import Optional + +import torch +import torch.nn as nn +from coati.models import BaseModel +from transformers import PretrainedConfig + + +class Critic(BaseModel): + """ + Critic model class. + + Args: + pretrained (str): path to pretrained model. + config (PretrainedConfig): PretrainedConfig used to initiate the base model. 
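# A minimal sketch of building a model on top of BaseModel from a config instead
# of a pretrained checkpoint; the tiny GPT2Config is an arbitrary choice, not
# something this patch prescribes (Critic's __init__/forward follow below).
import torch
from transformers import GPT2Config

tiny_cfg = GPT2Config(n_embd=32, n_layer=2, n_head=2, vocab_size=100)
critic = Critic(config=tiny_cfg)
scores = critic(torch.randint(0, 100, (2, 8)))  # one value per token, shape (2, 8)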
+ """ + + def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None: + super().__init__(pretrained=pretrained, config=config, **kwargs) + # et last hidden state size with dummy input + self.value_head = nn.Linear(self.last_hidden_state_size, 1) + + def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(input_ids, attention_mask=attention_mask) + last_hidden_states = outputs["last_hidden_state"] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), :].type( + self.value_head.weight.dtype + ) + values = self.value_head(sequence_hidden_states).squeeze(-1) # ensure shape is (B, sequence length) + return values diff --git a/applications/ColossalChat/coati/models/generation.py b/applications/ColossalChat/coati/models/generation.py new file mode 100755 index 000000000..b671ef124 --- /dev/null +++ b/applications/ColossalChat/coati/models/generation.py @@ -0,0 +1,428 @@ +from typing import Any, Callable, List, Optional + +import torch +import torch.distributed as dist +from transformers import PreTrainedTokenizer + +try: + from transformers.generation_logits_process import ( + LogitsProcessorList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) +except ImportError: + from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper + + +def _prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: + """ + Prepare the logits processor list based on the given parameters. + + Args: + top_k (Optional[int]): The number of highest probability logits to keep for each token. + top_p (Optional[float]): The cumulative probability threshold for selecting tokens. + temperature (Optional[float]): The temperature value to apply to the logits. + + Returns: + LogitsProcessorList: The list of logits processors. + + """ + processor_list = LogitsProcessorList() + if temperature is not None and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + processor_list.append(TopKLogitsWarper(top_k)) + if top_p is not None and top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + return processor_list + + +def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: + """ + Check if the sequence generation is finished. + + Args: + unfinished_sequences (torch.Tensor): Tensor indicating the unfinished sequences. + + Returns: + bool: True if all sequences are finished, False otherwise. + """ + if dist.is_initialized() and dist.get_world_size() > 1: + # consider DP + unfinished_sequences = unfinished_sequences.clone() + dist.all_reduce(unfinished_sequences) + return unfinished_sequences.max() == 0 + + +def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict: + """ + Update the model keyword arguments based on the outputs and new mask. + + Args: + outputs (dict): The outputs from the model. + new_mask: The new attention mask. + **model_kwargs: Additional model keyword arguments. + + Returns: + dict: The updated model keyword arguments. 
+ """ + + if "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs["past_key_values"] + else: + model_kwargs["past_key_values"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat([attention_mask, new_mask], dim=-1) + + return model_kwargs + + +def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict: + model_kwargs["input_ids"] = input_ids + return model_kwargs + + +def _sample( + model: Any, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = True, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + stop_token_ids: Optional[List[int]] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + max_new_tokens: int = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + stream_interval: int = 2, + **model_kwargs, +) -> torch.Tensor: + """ + Generates new tokens using the given model and input_ids. + + Args: + model (Any): The model used for token generation. + input_ids (torch.Tensor): The input tensor containing the initial tokens. + max_length (int): The maximum length of the generated tokens. + early_stopping (bool, optional): Whether to stop generating tokens early if all sequences are finished. Defaults to True. + eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None. + pad_token_id (int, optional): The ID of the padding token. Defaults to None. + stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will stop the generation process. Defaults to None. + top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None. + top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None. + temperature (float, optional): The temperature value for token sampling. Defaults to None. + max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None. + prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None. + update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model kwargs. Defaults to None. + stream_interval (int, optional): The interval for streaming generation. Defaults to 2. + **model_kwargs: Additional keyword arguments for the model. + + Returns: + torch.Tensor: The tensor containing the generated tokens. 
+ """ + context_length = input_ids.size(1) + if max_new_tokens is None: + max_new_tokens = max_length - context_length + if context_length + max_new_tokens > max_length or max_new_tokens == 0: + return input_ids + + logits_processor = _prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + past = None + for i in range(context_length, context_length + max_new_tokens): + # Calculate attention mask + if "attention_mask" not in model_kwargs: + model_kwargs["attention_mask"] = input_ids.ne(pad_token_id) + model_inputs = ( + prepare_inputs_fn(input_ids, past=past, **model_kwargs) + if prepare_inputs_fn is not None + else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)} + ) + outputs = model(**model_inputs) + + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + + # NOTE: this is correct only in left padding mode + next_token_logits = outputs["logits"][:, -1, :] + next_token_logits = logits_processor(input_ids, next_token_logits) + + # Sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # Finished sentences should have their next token be a padding token + if eos_token_id is not None: + assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # Update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if update_model_kwargs_fn is not None: + model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) + + # If eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + if stop_token_ids is not None: + # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished. + tokens_to_check = input_ids[:, -len(stop_token_ids) :] + unfinished_sequences = unfinished_sequences.mul( + torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long() + ) + + # Stop when each sentence is finished if early_stopping=True + if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1: + if i == context_length + max_new_tokens - 1: + # Force to end with stop token ids + input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = ( + torch.LongTensor(stop_token_ids).to(input_ids.device).long() + ) + return input_ids + + +@torch.inference_mode() +def generate( + model: Any, + input_ids: torch.Tensor, + tokenizer: PreTrainedTokenizer, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = True, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: + """Generate token sequence. The returned sequence is input_ids + generated_tokens. + + Args: + model (nn.Module): model + input_ids (torch.Tensor): input sequence + max_length (int): max length of the returned sequence + num_beams (int, optional): number of beams. Defaults to 1. 
+ do_sample (bool, optional): whether to do sample. Defaults to True. + early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. + top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. + top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. + temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. + prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. + update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. + """ + assert tokenizer.padding_side == "left", "Current generation only supports left padding." + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False + if is_greedy_gen_mode: + raise NotImplementedError + elif is_sample_gen_mode: + # Run sample + res = _sample( + model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs, + ) + return res + elif is_beam_gen_mode: + raise NotImplementedError + else: + raise ValueError("Unsupported generation mode") + + +def _sample_streaming( + model: Any, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + stop_token_ids: Optional[List[int]] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + max_new_tokens: int = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + stream_interval: int = 2, + **model_kwargs, +) -> torch.Tensor: + """ + Generates new tokens using a streaming approach. + + Args: + model (Any): The model used for token generation. + input_ids (torch.Tensor): The input tensor containing the initial tokens. + max_length (int): The maximum length of the generated sequence. + early_stopping (bool, optional): Whether to stop generating tokens for a sequence if it is finished. Defaults to False. + eos_token_id (int, optional): The ID of the end-of-sequence token. Defaults to None. + pad_token_id (int, optional): The ID of the padding token. Defaults to None. + stop_token_ids (List[int], optional): A list of token IDs that, if encountered, will mark the sequence as finished. Defaults to None. + top_k (int, optional): The number of top-k tokens to consider during sampling. Defaults to None. + top_p (float, optional): The cumulative probability threshold for top-p sampling. Defaults to None. + temperature (float, optional): The temperature value for sampling. Defaults to None. + max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to None. 
+ prepare_inputs_fn (Callable[[torch.Tensor, Any], dict], optional): A function to prepare the model inputs. Defaults to None. + update_model_kwargs_fn (Callable[[dict, Any], dict], optional): A function to update the model keyword arguments. Defaults to None. + stream_interval (int, optional): The interval at which to yield the generated tokens. Defaults to 2. + **model_kwargs: Additional keyword arguments to be passed to the model. + + Yields: + torch.Tensor: The generated tokens at each step. + + Returns: + torch.Tensor: The final generated tokens. + """ + + context_length = input_ids.size(1) + if max_new_tokens is None: + max_new_tokens = max_length - context_length + if context_length + max_new_tokens > max_length or max_new_tokens == 0: + return input_ids + + logits_processor = _prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + past = None + for i in range(context_length, context_length + max_new_tokens): + # calculate attention mask + if "attention_mask" not in model_kwargs: + model_kwargs["attention_mask"] = input_ids.ne(pad_token_id) + model_inputs = ( + prepare_inputs_fn(input_ids, past=past, **model_kwargs) + if prepare_inputs_fn is not None + else {"input_ids": input_ids, "attention_mask": input_ids.ne(pad_token_id)} + ) + outputs = model(**model_inputs) + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + + # NOTE: this is correct only in left padding mode + next_token_logits = outputs["logits"][:, -1, :] + next_token_logits = logits_processor(input_ids, next_token_logits) + # sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if update_model_kwargs_fn is not None: + model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + if stop_token_ids is not None: + # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished. 
+ tokens_to_check = input_ids[:, -len(stop_token_ids) :] + unfinished_sequences = unfinished_sequences.mul( + torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long() + ) + + # Stop when each sentence is finished if early_stopping=True + if ( + (early_stopping and _is_sequence_finished(unfinished_sequences)) + or (i - context_length) % stream_interval == 0 + or i == context_length + max_new_tokens - 1 + ): + yield input_ids + if early_stopping and _is_sequence_finished(unfinished_sequences): + break + + +@torch.inference_mode() +def generate_streaming( + model: Any, + input_ids: torch.Tensor, + tokenizer: PreTrainedTokenizer, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +): + """Generate token sequence. The returned sequence is input_ids + generated_tokens. + + Args: + model (nn.Module): model + input_ids (torch.Tensor): input sequence + max_length (int): max length of the returned sequence + num_beams (int, optional): number of beams. Defaults to 1. + do_sample (bool, optional): whether to do sample. Defaults to True. + early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False. + top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. + top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None. + temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None. + prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. + update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. + """ + assert tokenizer.padding_side == "left", "Current generation only supports left padding." 
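# The generation helpers above assert left padding. A minimal tokenizer setup
# sketch; "gpt2" is only an example checkpoint, any causal-LM tokenizer works.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token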
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False + if is_greedy_gen_mode: + # run greedy search + raise NotImplementedError + elif is_sample_gen_mode: + # run sample + for res in _sample_streaming( + model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs, + ): + yield res + elif is_beam_gen_mode: + raise NotImplementedError + else: + raise ValueError("Unsupported generation mode") diff --git a/applications/Chat/coati/models/lora.py b/applications/ColossalChat/coati/models/lora.py old mode 100644 new mode 100755 similarity index 85% rename from applications/Chat/coati/models/lora.py rename to applications/ColossalChat/coati/models/lora.py index e9bd7b2ed..9553b00ff --- a/applications/Chat/coati/models/lora.py +++ b/applications/ColossalChat/coati/models/lora.py @@ -1,3 +1,7 @@ +""" +LORA utils +""" + import dataclasses import math import warnings @@ -8,6 +12,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + @dataclasses.dataclass class LoRAManager: @@ -58,6 +66,10 @@ class LoraLinear(lora.LoRALayer, nn.Module): nn.init.zeros_(self.lora_B) def train(self, mode: bool = True): + """ + This function runs when model.train() is invoked. It is used to prepare the linear layer for training + """ + def T(w): return w.T if self.fan_in_fan_out else w @@ -101,6 +113,16 @@ class LoraLinear(lora.LoRALayer, nn.Module): def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: + """ + Wraps a linear layer with LoRA functionality. + + Args: + linear (nn.Linear): The linear layer to be wrapped. + lora_rank (int): The rank of the LoRA decomposition. + + Returns: + LoraLinear: The wrapped linear layer with LoRA functionality. + """ assert ( lora_rank <= linear.in_features ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" @@ -109,6 +131,16 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: + """ + Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form. + + Args: + module (nn.Module): The module to convert to LoRA form. + lora_rank (int): The rank of the LoRA approximation. + + Returns: + None + """ for name, child in module.named_children(): if isinstance(child, nn.Linear): setattr(module, name, _lora_linear_wrapper(child, lora_rank)) @@ -131,23 +163,3 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s _convert_to_lora_recursively(module, lora_rank) lora.mark_only_lora_as_trainable(module, lora_train_bias) return module - - -class LoRAModule(nn.Module): - """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`. - This class will convert all torch.nn.Linear layer to LoraLinear layer. - - Args: - lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0. - lora_train_bias (str, optional): Whether LoRA train biases. - 'none' means it doesn't train biases. 'all' means it trains all biases. 
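# A small sketch of `convert_to_lora_module` above: every nn.Linear child is
# wrapped in a LoraLinear and only the low-rank factors stay trainable
# (assumes the loralib dependency used by this module is installed).
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model = convert_to_lora_module(model, lora_rank=4, lora_train_bias="none")
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
# `trainable` now counts only the LoRA A/B matrices, not the frozen base weights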
'lora_only' means it only trains biases of LoRA layers. - Defaults to 'none'. - """ - - def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None: - super().__init__() - self.lora_rank = lora_rank - self.lora_train_bias = lora_train_bias - - def convert_to_lora(self) -> None: - convert_to_lora_module(self, self.lora_rank, self.lora_train_bias) diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py new file mode 100755 index 000000000..aaef447a4 --- /dev/null +++ b/applications/ColossalChat/coati/models/loss.py @@ -0,0 +1,169 @@ +""" +loss functions +""" +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from .utils import masked_mean + + +class GPTLMLoss(nn.Module): + """ + GPT Language Model Loss + """ + + def __init__(self): + super().__init__() + # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py + self.loss = nn.CrossEntropyLoss() + + def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None: + super().__init__() + self.clip_eps = clip_eps + self.skip_threshold = skip_threshold + + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + skip = False + ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + + # note that if dropout is disabled (recommanded), ratio will always be 1. + if ratio_.mean() > self.skip_threshold: + skip = True + + ratio = ratio_.clamp(0.0, 10.0) + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + loss = masked_mean(loss, action_mask) + loss = loss.mean() + return loss, skip, ratio_.max() + + +class ValueLoss(nn.Module): + """ + Value Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward( + self, + values: torch.Tensor, + old_values: torch.Tensor, + advantage: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + returns = advantage + old_values + values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) + surr1 = (values_clipped - returns) ** 2 + surr2 = (values - returns) ** 2 + loss = torch.max(surr1, surr2) / torch.sum(action_mask) + loss = torch.sum(loss * action_mask) + return 0.5 * loss + + +class DpoLoss(nn.Module): + """ + Dpo loss + Details: https://arxiv.org/pdf/2305.18290.pdf + """ + + def __init__(self, beta: float = 0.1): + super().__init__() + self.beta = beta + + def forward( + self, + logprob_actor_chosen: torch.Tensor, + logprob_actor_reject: torch.Tensor, + logprob_ref_chosen: torch.Tensor, + logprob_ref_reject: torch.Tensor, + chosen_mask: torch.Tensor, + reject_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. 
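# A toy sketch of the clipped PPO surrogate in `PolicyLoss` above: identical new
# and old log-probs give ratio 1, so the loss reduces to -mean(advantages).
import torch

loss_fn = PolicyLoss(clip_eps=0.2)
log_probs = torch.tensor([[-1.0, -2.0]])
advantages = torch.tensor([[0.5, -0.5]])
mask = torch.ones(1, 2)
loss, skip, max_ratio = loss_fn(log_probs, log_probs.clone(), advantages, mask)
# loss == 0.0 here, skip is False, max_ratio == 1.0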
+ + # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328 + + Args: + logprob_actor_chosen: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + logprob_actor_chosen = logprob_actor_chosen * chosen_mask + logprob_actor_reject = logprob_actor_reject * reject_mask + if logprob_ref_chosen is not None and logprob_ref_reject is not None: + logprob_ref_chosen = logprob_ref_chosen * chosen_mask + logprob_ref_reject = logprob_ref_reject * reject_mask + if len(logprob_ref_chosen.shape) == 2: + ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1) + else: + ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze() + else: + # If no reference model is provided + ref_logratios = 0.0 + + pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1) + logits = pi_logratios - ref_logratios + losses = -torch.nn.functional.logsigmoid(self.beta * logits) + + # Calculate rewards for logging + if logprob_ref_chosen is not None: + chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach() + else: + chosen_rewards = self.beta * logprob_actor_chosen.sum(-1).detach() + if logprob_ref_reject is not None: + rejected_rewards = self.beta * (logprob_actor_reject.sum(-1) - logprob_ref_reject.sum(-1)).detach() + else: + rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach() + + return losses, chosen_rewards, rejected_rewards + + +class LogSigLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2203.02155 + """ + + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + return -torch.nn.functional.logsigmoid(chosen_reward - reject_reward).mean() + + +class LogExpLoss(nn.Module): + """ + Pairwise Loss for Reward Model + Details: https://arxiv.org/abs/2204.05862 + """ + + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: + loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() + return loss diff --git a/applications/ColossalChat/coati/models/reward_model.py b/applications/ColossalChat/coati/models/reward_model.py new file mode 100755 index 000000000..18c5eca41 --- /dev/null +++ b/applications/ColossalChat/coati/models/reward_model.py @@ -0,0 +1,38 @@ +""" +reward model +""" +from typing import Optional + +import torch +import torch.nn as nn +from coati.models import BaseModel +from transformers import PretrainedConfig + + +class RewardModel(BaseModel): + """ + Reward model class. 
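# A toy sketch of the two pairwise ranking losses above; both shrink as the
# chosen reward exceeds the rejected one, and they are mathematically identical
# (-logsigmoid(x) == log(1 + exp(-x))), differing only in how they are written.
import torch

chosen = torch.tensor([1.0, 2.0])
rejected = torch.tensor([0.0, 0.5])
loss_sig = LogSigLoss()(chosen, rejected)
loss_exp = LogExpLoss()(chosen, rejected)
# loss_sig and loss_exp agree up to floating-point error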
+ + Args: + pretrained str: huggingface or local model path + config: PretrainedConfig object + **kwargs: all other kwargs as in AutoModel.from_pretrained + """ + + def __init__(self, pretrained: str = None, config: Optional[PretrainedConfig] = None, **kwargs) -> None: + super().__init__(pretrained=pretrained, config=config, **kwargs) + self.value_head = nn.Linear(self.last_hidden_state_size, 1) + self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1)) + + def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + outputs = self.model(input_ids, attention_mask=attention_mask) + + last_hidden_states = outputs["last_hidden_state"] + sequence_lengths = torch.max(attention_mask * torch.arange(input_ids.size(1), device=input_ids.device), dim=1)[ + 0 + ] + sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths].type( + self.value_head.weight.dtype + ) + values = self.value_head(sequence_hidden_states).squeeze(-1) # Ensure shape is (B,) + return values diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py new file mode 100755 index 000000000..ce672534c --- /dev/null +++ b/applications/ColossalChat/coati/models/utils.py @@ -0,0 +1,137 @@ +import json +import os +from typing import Any, Dict, Optional, Union + +import torch +import torch.nn.functional as F + + +def get_model_numel(model: torch.nn.Module) -> int: + return sum(p.numel() for p in model.parameters()) + + +def compute_reward( + r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + reward_eps=5, +) -> torch.Tensor: + """ + Args: + log_probs: [batch_size, response_length] + log_probs_base: [batch_size, response_length] + action_mask: [batch_size, response_length] + r: float + Returns: + reward: [batch_size, response_length] + """ + log_ratio = log_probs - log_probs_base # address numerical instability issue + kl = -kl_coef * log_ratio * action_mask + reward = kl + r_clip = torch.clamp(r, -reward_eps, reward_eps) + for i in range(action_mask.size(0)): + assert action_mask[i].sum() > 0 + reward[i, : action_mask[i].sum()] += r_clip[i] + reward[i, action_mask[i].sum() :] *= 0 + return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask + + +def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Compute the log probabilities from logits for the given labels. + + Args: + logits (torch.Tensor): The input logits. + labels (torch.Tensor): The target labels. + + Returns: + torch.Tensor: The log probabilities corresponding to the labels. + """ + log_probs = F.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return log_probs_labels.squeeze(-1) + + +def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: + """Calculate action log probs. + + Args: + output (torch.Tensor): Output tensor of Actor.forward.logits. + sequences (torch.LongTensor): Input sequences. + num_actions (int): Number of actions. + + Returns: + torch.Tensor: Action log probs. 
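# A toy sketch of `compute_reward` above: the per-token reward is the KL penalty
# -kl_coef * (log_probs - log_probs_base), with the clipped scalar reward added
# onto every unmasked action position.
import torch

log_probs = torch.tensor([[-1.0, -1.2, -0.8]])
log_probs_base = torch.tensor([[-1.1, -1.0, -0.9]])
action_mask = torch.ones(1, 3, dtype=torch.bool)
r = torch.tensor([0.5])
reward, kl = compute_reward(r, 0.1, log_probs, log_probs_base, action_mask=action_mask)
# reward[0] == tensor([0.49, 0.52, 0.49]) up to floating-point error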
+ """ + log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + +def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked mean of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the mean. Default is 1. + + Returns: + torch.Tensor: The masked mean tensor. + + """ + tensor = tensor * mask + tensor = tensor.sum(dim=dim) + mask_sum = mask.sum(dim=dim) + mean = tensor / (mask_sum + 1e-8) + return mean + + +def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor: + """ + Calculate the masked log probabilities for a given sequence of logits. + + Args: + logits (torch.Tensor): The input logits tensor of shape (batch_size, sequence_length, vocab_size). + sequences (torch.LongTensor): The input sequence tensor of shape (batch_size, sequence_length). + mask (torch.Tensor): The mask tensor of shape (batch_size, sequence_length). + + Returns: + torch.Tensor: The masked log probabilities tensor of shape (batch_size, sequence_length - 1). + """ + # logits are probabilities of the next token, so we shift them to the left by one + log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs * mask + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def disable_dropout(model: torch.nn.Module): + """ + Disables dropout in a PyTorch model. This is used in PPO Training + + Args: + model (torch.nn.Module): The PyTorch model. 
+ + Returns: + None + """ + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0.0 diff --git a/applications/Chat/coati/quant/__init__.py b/applications/ColossalChat/coati/quant/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/__init__.py rename to applications/ColossalChat/coati/quant/__init__.py diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/ColossalChat/coati/quant/llama_gptq/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/llama_gptq/__init__.py rename to applications/ColossalChat/coati/quant/llama_gptq/__init__.py diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/ColossalChat/coati/quant/llama_gptq/loader.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/llama_gptq/loader.py rename to applications/ColossalChat/coati/quant/llama_gptq/loader.py diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/ColossalChat/coati/quant/llama_gptq/model_utils.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/llama_gptq/model_utils.py rename to applications/ColossalChat/coati/quant/llama_gptq/model_utils.py diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/ColossalChat/coati/quant/llama_gptq/quant.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/llama_gptq/quant.py rename to applications/ColossalChat/coati/quant/llama_gptq/quant.py diff --git a/applications/Chat/coati/quant/utils.py b/applications/ColossalChat/coati/quant/utils.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/quant/utils.py rename to applications/ColossalChat/coati/quant/utils.py diff --git a/applications/Chat/coati/ray/README.md b/applications/ColossalChat/coati/ray/README.md old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/README.md rename to applications/ColossalChat/coati/ray/README.md diff --git a/applications/Chat/coati/ray/__init__.py b/applications/ColossalChat/coati/ray/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/__init__.py rename to applications/ColossalChat/coati/ray/__init__.py diff --git a/applications/Chat/coati/ray/callbacks/__init__.py b/applications/ColossalChat/coati/ray/callbacks/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/callbacks/__init__.py rename to applications/ColossalChat/coati/ray/callbacks/__init__.py diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/ColossalChat/coati/ray/callbacks/base.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/callbacks/base.py rename to applications/ColossalChat/coati/ray/callbacks/base.py diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/callbacks/performance_evaluator.py rename to applications/ColossalChat/coati/ray/callbacks/performance_evaluator.py diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/ColossalChat/coati/ray/detached_replay_buffer.py old mode 100644 
new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/detached_replay_buffer.py rename to applications/ColossalChat/coati/ray/detached_replay_buffer.py diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/ColossalChat/coati/ray/detached_trainer_base.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/detached_trainer_base.py rename to applications/ColossalChat/coati/ray/detached_trainer_base.py diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/ColossalChat/coati/ray/detached_trainer_ppo.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/detached_trainer_ppo.py rename to applications/ColossalChat/coati/ray/detached_trainer_ppo.py diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/ColossalChat/coati/ray/experience_maker_holder.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/experience_maker_holder.py rename to applications/ColossalChat/coati/ray/experience_maker_holder.py diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/ColossalChat/coati/ray/lora_constructor.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/coati/ray/lora_constructor.py rename to applications/ColossalChat/coati/ray/lora_constructor.py diff --git a/applications/Chat/coati/ray/utils.py b/applications/ColossalChat/coati/ray/utils.py old mode 100644 new mode 100755 similarity index 97% rename from applications/Chat/coati/ray/utils.py rename to applications/ColossalChat/coati/ray/utils.py index b88140c0e..4882f00b7 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/ColossalChat/coati/ray/utils.py @@ -75,7 +75,9 @@ def get_strategy_from_args(strategy: str): elif strategy == "colossalai_zero2": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") elif strategy == "colossalai_gemini_cpu": - strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + strategy_ = GeminiStrategy( + placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 + ) elif strategy == "colossalai_zero2_cpu": strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git a/applications/ColossalChat/coati/trainer/__init__.py b/applications/ColossalChat/coati/trainer/__init__.py new file mode 100755 index 000000000..2eff8ca76 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/__init__.py @@ -0,0 +1,7 @@ +from .base import OLTrainer, SLTrainer +from .dpo import DPOTrainer +from .ppo import PPOTrainer +from .rm import RewardModelTrainer +from .sft import SFTTrainer + +__all__ = ["SLTrainer", "OLTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer", "DPOTrainer"] diff --git a/applications/Chat/coati/trainer/base.py b/applications/ColossalChat/coati/trainer/base.py old mode 100644 new mode 100755 similarity index 80% rename from applications/Chat/coati/trainer/base.py rename to applications/ColossalChat/coati/trainer/base.py index 0a41d450d..63c903a51 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/ColossalChat/coati/trainer/base.py @@ -1,6 +1,14 @@ +""" +Base trainers for online and offline training + SLTrainer: supervised learning trainer + pretrain, sft, dpo, reward model training + OLTrainer: online learning trainer + rlhf-ppo +""" + from abc import ABC, abstractmethod from 
contextlib import contextmanager -from typing import List +from typing import Callable, List import torch.nn as nn import tqdm @@ -8,8 +16,8 @@ from coati.experience_buffer import NaiveExperienceBuffer from coati.experience_maker import Experience from torch.optim import Optimizer -from .callbacks import Callback -from .strategies import Strategy +from colossalai.booster import Booster + from .utils import is_rank_0 @@ -26,16 +34,18 @@ class SLTrainer(ABC): def __init__( self, - strategy: Strategy, + booster: Booster, max_epochs: int, model: nn.Module, optimizer: Optimizer, + start_epoch: int = 0, ) -> None: super().__init__() - self.strategy = strategy + self.booster = booster self.max_epochs = max_epochs self.model = model self.optimizer = optimizer + self.start_epoch = start_epoch @abstractmethod def _train(self, epoch): @@ -45,19 +55,20 @@ class SLTrainer(ABC): def _eval(self, epoch): raise NotImplementedError() + @abstractmethod def _before_fit(self): raise NotImplementedError() def fit(self, *args, **kwargs): self._before_fit(*args, **kwargs) - for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()): + for epoch in tqdm.trange(self.start_epoch, self.max_epochs, desc="Epochs", disable=not is_rank_0()): self._train(epoch) self._eval(epoch) -class OnPolicyTrainer(ABC): +class OLTrainer(ABC): """ - Base class for on-policy rl trainers, e.g. PPO. + Base class for online learning trainers, e.g. PPO. Args: strategy (Strategy):the strategy to use for training @@ -69,14 +80,16 @@ class OnPolicyTrainer(ABC): def __init__( self, - strategy: Strategy, + actor_booster: Booster, + critic_booster: Booster, data_buffer: NaiveExperienceBuffer, sample_buffer: bool, dataloader_pin_memory: bool, - callbacks: List[Callback] = [], + callbacks: List[Callable] = [], ) -> None: super().__init__() - self.strategy = strategy + self.actor_booster = actor_booster + self.critic_booster = critic_booster self.data_buffer = data_buffer self.sample_buffer = sample_buffer self.dataloader_pin_memory = dataloader_pin_memory @@ -141,6 +154,20 @@ class OnPolicyTrainer(ABC): """ raise NotImplementedError() + @abstractmethod + def _setup_update_phrase_dataload(self): + """ + Implement this method to setup dataloader for update phase. + """ + raise NotImplementedError() + + @abstractmethod + def _save_checkpoint(self, episode: int = 0): + """ + Implement this method to save checkpoint. + """ + raise NotImplementedError() + def _collect_phase(self, collect_step: int): self._on_make_experience_start() experience = self._make_experience(collect_step) @@ -178,11 +205,10 @@ class OnPolicyTrainer(ABC): for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()): self._collect_phase(collect_step) if not self.sample_buffer: - # HACK(cwher): according to the design of boost API, dataloader should also be boosted, - # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. - # I only call strategy.setup_dataloader() to setup dataloader. 
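# A minimal sketch of the SLTrainer contract after the Booster refactor above:
# subclasses implement _before_fit / _train / _eval, and fit() drives the epochs
# from start_epoch to max_epochs. The booster, model and optimizer passed to
# __init__ are assumed to have been prepared with Booster.boost beforehand.
from coati.trainer import SLTrainer

class ToySFTTrainer(SLTrainer):
    def _before_fit(self, train_dataloader=None):
        self.train_dataloader = train_dataloader

    def _train(self, epoch):
        # one boosted epoch would go here: forward, booster.backward(loss, optimizer), step
        pass

    def _eval(self, epoch):
        pass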
- self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory) + self._setup_update_phrase_dataload() for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()): self._update_phase(update_step) # NOTE: this is for on-policy algorithms self.data_buffer.clear() + if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0: + self._save_checkpoint(episode + 1) diff --git a/applications/ColossalChat/coati/trainer/callbacks/__init__.py b/applications/ColossalChat/coati/trainer/callbacks/__init__.py new file mode 100644 index 000000000..a76548507 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/callbacks/__init__.py @@ -0,0 +1,4 @@ +from .base import Callback +from .performance_evaluator import PerformanceEvaluator + +__all__ = ["Callback", "PerformanceEvaluator"] diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/ColossalChat/coati/trainer/callbacks/base.py similarity index 100% rename from applications/Chat/coati/trainer/callbacks/base.py rename to applications/ColossalChat/coati/trainer/callbacks/base.py diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py similarity index 92% rename from applications/Chat/coati/trainer/callbacks/performance_evaluator.py rename to applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py index b286c766c..86384e5e3 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/ColossalChat/coati/trainer/callbacks/performance_evaluator.py @@ -14,9 +14,11 @@ def get_world_size() -> int: return 1 -def print_rank_0(*args, **kwargs) -> None: +def save_eval_result_rank_0(s: str, save_path: str, **kwargs) -> None: if not dist.is_initialized() or dist.get_rank() == 0: - print(*args, **kwargs) + with open(save_path, "a+") as f: + train_config = "; ".join([str(kwargs[key]) for key in kwargs]) + f.write(train_config + "\n" + s + "\n") def divide(x: float, y: float) -> float: @@ -74,6 +76,8 @@ class PerformanceEvaluator(Callback): reward_model_num_params: int, enable_grad_checkpoint: bool = False, ignore_episodes: int = 0, + train_config: Optional[dict] = None, + save_path: Optional[str] = None, ) -> None: super().__init__() self.world_size = get_world_size() @@ -92,6 +96,8 @@ class PerformanceEvaluator(Callback): self.make_experience_flop: int = 0 self.learn_num_samples: int = 0 self.learn_flop: int = 0 + self.train_config = train_config + self.save_path = save_path def on_episode_start(self, episode: int) -> None: self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes @@ -172,12 +178,14 @@ class PerformanceEvaluator(Callback): make_experience_time_per_sample = divide(avg_make_experience_duration, num_effective_samples) learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) - print_rank_0( + save_eval_result_rank_0( f"Performance summary:\n" + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n" + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n" + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n" + f"Overall time per sample: {overall_time_per_sample:.2f} s\n" + f"Make experience time per sample: 
{make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n" - + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%" + + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%", + self.save_path, + **self.train_config, ) diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py new file mode 100755 index 000000000..cbe7d7ca8 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/dpo.py @@ -0,0 +1,336 @@ +""" +DPO trainer +""" + +from typing import Any, Optional + +import torch +from coati.models.loss import DpoLoss +from coati.models.utils import calc_masked_log_probs +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import trange +from transformers import PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import SLTrainer +from .utils import is_rank_0, to_device + + +class DPOTrainer(SLTrainer): + """ + Trainer for DPO (Direct Preference Optimization) algorithm. + + Args: + actor (Actor): the actor model in dpo algorithm + ref_model (Actor): the reference model in dpo algorithm + booster (Booster): the booster to use for training + actor_optim (Optimizer): the optimizer to use for actor model + actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model + tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding + max_epochs (int, defaults to 1): the max number of epochs to train + beta (float, defaults to 0.1): the beta parameter in dpo loss + accumulation_steps (int): the number of steps to accumulate gradients + start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint + save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training + save_dir (str): the directory to save checkpoints + coordinator (DistCoordinator): the coordinator to use for distributed logging + """ + + def __init__( + self, + actor: Any, + ref_model: Any, + booster: Booster, + actor_optim: Optimizer, + actor_lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + max_epochs: int = 1, + beta: float = 0.1, + accumulation_steps: int = 1, + start_epoch: int = 0, + save_interval: int = 0, + save_dir: str = None, + coordinator: DistCoordinator = None, + ) -> None: + super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) + self.ref_model = ref_model + self.actor_scheduler = actor_lr_scheduler + self.tokenizer = tokenizer + self.actor_loss_fn = DpoLoss(beta) + self.save_interval = save_interval + self.coordinator = coordinator + self.save_dir = save_dir + self.num_train_step = 0 + self.accumulation_steps = accumulation_steps + self.device = get_current_device() + self.accumulative_meter = AccumulativeMeanMeter() + + def _before_fit( + self, + train_preference_dataloader: DataLoader = None, + eval_preference_dataloader: DataLoader = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader
(DataLoader): the dataloader to use for pretrain data + """ + self.train_dataloader = train_preference_dataloader + self.eval_dataloader = eval_preference_dataloader + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "dpo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _train(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + self.model.train() + self.accumulative_meter.reset() + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, self.device) + ( + chosen_input_ids, + chosen_attention_mask, + chosen_loss_mask, + reject_input_ids, + reject_attention_mask, + reject_loss_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["chosen_loss_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + batch["reject_loss_mask"], + ) + reject_loss_mask[:, -1] = False + batch_size = chosen_input_ids.size()[0] + + actor_all_logits = self.model( + input_ids=torch.cat([chosen_input_ids, reject_input_ids]), + attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), + )["logits"].to(torch.float32) + actor_chosen_logits = actor_all_logits[:batch_size] + actor_reject_logits = actor_all_logits[batch_size:] + logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) + + logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) + + if self.ref_model is not None: + self.ref_model.eval() + with torch.no_grad(): + ref_all_logits = self.ref_model( + input_ids=torch.cat([chosen_input_ids, reject_input_ids]), + attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]), + )["logits"].to(torch.float32) + ref_chosen_logits = ref_all_logits[:batch_size] + ref_reject_logits = ref_all_logits[batch_size:] + logprob_ref_chosen = calc_masked_log_probs( + ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:] + ) + logprob_ref_reject = calc_masked_log_probs( + ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] + ) + else: + logprob_ref_chosen = None + logprob_ref_reject = None + + losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( + logprob_actor_chosen, + logprob_actor_reject, + logprob_ref_chosen if logprob_ref_chosen is not None else None, + logprob_ref_reject if logprob_ref_reject is not None else None, + chosen_loss_mask[:, 1:], + reject_loss_mask[:, 1:], + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float().mean() + + # DPO Loss + loss = losses.mean() + + self.booster.backward(loss=loss, optimizer=self.optimizer) + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: + self.optimizer.step() + self.optimizer.zero_grad() + self.actor_scheduler.step() + + # sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) + 
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) + + if i % self.accumulation_steps == self.accumulation_steps - 1: + self.num_train_step += 1 + step_bar.update() + # logging + if self.writer and is_rank_0(): + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar( + "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + ) + self.writer.add_scalar( + "train/rejected_rewards", + self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/margin", + self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/accuracy", + self.accumulative_meter.get("accuracy"), + self.num_train_step, + ) + self.accumulative_meter.reset() + + if (self.num_train_step + 1) % self.save_interval == 0: + # save checkpoint + self.coordinator.print_on_master("\nStart saving model checkpoint with running states") + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.actor_scheduler, + epoch=epoch, + step=i + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}" + ) + + step_bar.close() + + def _eval(self, epoch: int): + """ + Args: + epoch int: the number of current epoch + """ + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.model.eval() + self.ref_model.eval() + self.coordinator.print_on_master("\nStart evaluation...") + + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + + self.accumulative_meter.reset() + + with torch.no_grad(): + for i, batch in enumerate(self.eval_dataloader): + batch = to_device(batch, self.device) + ( + chosen_input_ids, + chosen_attention_mask, + chosen_loss_mask, + reject_input_ids, + reject_attention_mask, + reject_loss_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["chosen_loss_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + batch["reject_loss_mask"], + ) + + batch_size = chosen_input_ids.size()[0] + + actor_all_logits = self.model( + torch.cat([chosen_input_ids, reject_input_ids]), + torch.cat([chosen_attention_mask, reject_attention_mask]), + )["logits"].to(torch.float32) + actor_chosen_logits = actor_all_logits[:batch_size] + actor_reject_logits = actor_all_logits[batch_size:] + + logprob_actor_chosen = calc_masked_log_probs( + actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:] + ) + + logprob_actor_reject = calc_masked_log_probs( + actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:] + ) + + self.ref_model.eval() + + ref_all_logits = self.ref_model( + torch.cat([chosen_input_ids, 
reject_input_ids]), + torch.cat([chosen_attention_mask, reject_attention_mask]), + )["logits"].to(torch.float32) + ref_chosen_logits = ref_all_logits[:batch_size] + ref_reject_logits = ref_all_logits[batch_size:] + logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]) + logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]) + + losses, chosen_rewards, rejected_rewards = self.actor_loss_fn( + logprob_actor_chosen, + logprob_actor_reject, + logprob_ref_chosen if logprob_ref_chosen is not None else None, + logprob_ref_reject if logprob_ref_reject is not None else None, + chosen_loss_mask[:, 1:], + reject_loss_mask[:, 1:], + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + loss = losses.mean() + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards) + rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards) + reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item()) + self.accumulative_meter.add( + "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item() + ) + step_bar.update() + + msg = "Evaluation Result:\n" + for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/ppo.py b/applications/ColossalChat/coati/trainer/ppo.py new file mode 100755 index 000000000..287767669 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/ppo.py @@ -0,0 +1,403 @@ +""" +PPO trainer +""" + +import os +from typing import Dict, List, Optional + +import torch +import wandb +from coati.experience_buffer import NaiveExperienceBuffer +from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.models import Critic, RewardModel +from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss +from coati.models.utils import calc_action_log_probs +from coati.trainer.callbacks import Callback +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import OLTrainer +from .utils import CycledDataLoader, is_rank_0, to_device + + +def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict: + """ + Set default keyword arguments for generation based on the actor model. + + Args: + actor (PreTrainedModel): The actor model. + + Returns: + Dict: A dictionary containing the default keyword arguments for generation. 
+ """ + unwrapped_model = actor.unwrap() + new_kwargs = {} + # use huggingface models method directly + if hasattr(unwrapped_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation + + if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation + return new_kwargs + + +class PPOTrainer(OLTrainer): + """ + Trainer for PPO algorithm. + + Args: + strategy (Booster): the strategy to use for training + actor (Actor): the actor model in ppo algorithm + critic (Critic): the critic model in ppo algorithm + reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences + initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor + actor_optim (Optimizer): the optimizer to use for actor model + critic_optim (Optimizer): the optimizer to use for critic model + kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitation of buffer + buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + vf_coef (float, defaults to 1.0): the coefficient of value loss + ptx_coef (float, defaults to 0.9): the coefficient of ptx loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + sample_buffer (bool, defaults to False): whether to sample from buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__( + self, + actor_booster: Booster, + critic_booster: Booster, + actor: PreTrainedModel, + critic: Critic, + reward_model: RewardModel, + initial_model: PreTrainedModel, + actor_optim: Optimizer, + critic_optim: Optimizer, + actor_lr_scheduler: _LRScheduler, + critic_lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.2, + sample_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + accumulation_steps: int = 1, + save_interval: int = 0, + save_dir: str = None, + use_tp: bool = False, + coordinator: DistCoordinator = None, + callbacks: List[Callback] = [], + **generate_kwargs, + ) -> None: + if isinstance(actor_booster, GeminiPlugin): + assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" + + data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + super().__init__( + actor_booster, critic_booster, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks + ) + self.generate_kwargs = _set_default_generate_kwargs(actor) + self.generate_kwargs.update(generate_kwargs) + + self.actor = actor + self.critic = critic + self.actor_booster = actor_booster + self.critic_booster = critic_booster + self.actor_scheduler = 
actor_lr_scheduler + self.critic_scheduler = critic_lr_scheduler + self.tokenizer = tokenizer + self.experience_maker = NaiveExperienceMaker( + self.actor, self.critic, reward_model, initial_model, self.tokenizer, kl_coef + ) + self.train_batch_size = train_batch_size + + self.actor_loss_fn = PolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + self.vf_coef = vf_coef + self.ptx_loss_fn = GPTLMLoss() + self.ptx_coef = ptx_coef + self.actor_optim = actor_optim + self.critic_optim = critic_optim + self.save_interval = save_interval + self.coordinator = coordinator + self.actor_save_dir = os.path.join(save_dir, "actor") + self.critic_save_dir = os.path.join(save_dir, "critic") + self.num_train_step = 0 + self.accumulation_steps = accumulation_steps + self.use_tp = use_tp + self.accumulative_meter = AccumulativeMeanMeter() + self.offload_inference_models = offload_inference_models + self.device = get_current_device() + + def _before_fit( + self, + prompt_dataloader: DataLoader, + pretrain_dataloader: Optional[DataLoader] = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + prompt_dataloader (DataLoader): the dataloader to use for prompt data + pretrain_dataloader (DataLoader): the dataloader to use for pretrain data + """ + self.prompt_dataloader = CycledDataLoader(prompt_dataloader) + self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-ppo", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "ppo") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _setup_update_phrase_dataload(self): + """ + why not use distributed_dataloader? + if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks + if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank + """ + self.dataloader = DataLoader( + self.data_buffer, + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True, + pin_memory=self.dataloader_pin_memory, + collate_fn=self.data_buffer.collate_fn, + ) + + def _make_experience(self, collect_step: int) -> Experience: + """ + Make experience + """ + prompts = self.prompt_dataloader.next() + if self.offload_inference_models: + # TODO(ver217): this may be controlled by strategy if they are prepared by strategy + self.experience_maker.initial_model.to(self.device) + self.experience_maker.reward_model.to(self.device) + return self.experience_maker.make_experience( + input_ids=prompts["input_ids"].to(get_current_device()), + attention_mask=prompts["attention_mask"].to(get_current_device()), + **self.generate_kwargs, + ) + + def _training_step(self, experience: Experience): + """ + Args: + experience: + sequences: [batch_size, prompt_length + response_length] --- ............ 
+ """ + self.num_train_step += 1 + self.actor.train() + self.critic.train() + num_actions = experience.action_log_probs.size(1) + # policy loss + + actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[ + "logits" + ] # [batch size, prompt_length + response_length] + action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions) + + actor_loss, to_skip, max_ratio = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) + actor_loss = (1 - self.ptx_coef) * actor_loss + if not to_skip: + self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim) + + # ptx loss + if self.ptx_coef != 0: + batch = self.pretrain_dataloader.next() + batch = to_device(batch, self.device) + outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + ptx_loss = outputs.loss + ptx_loss = self.ptx_coef * ptx_loss + self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim) + + # value loss + values = self.critic( + input_ids=experience.sequences, attention_mask=experience.attention_mask + ) # [batch size, prompt_length + response_length] + critic_loss = self.critic_loss_fn( + values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask + ) + critic_loss = critic_loss * self.vf_coef + self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim) + + # sync + actor_loss_mean = all_reduce_mean(tensor=actor_loss) + critic_loss_mean = all_reduce_mean(tensor=critic_loss) + max_ratio_mean = all_reduce_mean(tensor=max_ratio) + reward_mean = all_reduce_mean(tensor=experience.reward.mean()) + value_mean = all_reduce_mean(tensor=experience.values.mean()) + advantages_mean = all_reduce_mean(tensor=experience.advantages.mean()) + kl_mean = all_reduce_mean(tensor=experience.kl.mean()) + if self.ptx_coef != 0: + ptx_loss_mean = all_reduce_mean(tensor=ptx_loss) + + self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("critic_loss", critic_loss_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item()) + self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("value", value_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item()) + self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0) + self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item()) + if self.ptx_coef != 0: + self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item()) + + if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1: + self.actor_optim.step() + self.critic_optim.step() + self.actor_optim.zero_grad() + self.critic_optim.zero_grad() + self.actor_scheduler.step() + self.critic_scheduler.step() + + # preparing logging model output and corresponding rewards. 
+ if self.num_train_step % 10 == 1: + response_text = self.experience_maker.tokenizer.batch_decode( + experience.sequences, skip_special_tokens=True + ) + for i in range(len(response_text)): + response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}" + + if self.writer and is_rank_0() and "wandb_run" in self.__dict__: + # log output to wandb + my_table = wandb.Table( + columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text] + ) + try: + self.wandb_run.log({"sample_response": my_table}) + except OSError as e: + self.coordinator.print_on_master(e) + elif self.writer and is_rank_0(): + for line in response_text: + self.coordinator.print_on_master(line) + + if self.writer and is_rank_0(): + self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), self.num_train_step) + self.writer.add_scalar( + "train/skip_ratio", self.accumulative_meter.get("skip_ratio"), self.num_train_step + ) + self.writer.add_scalar( + "train/actor_loss", self.accumulative_meter.get("actor_loss"), self.num_train_step + ) + self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar("train/lr_critic", self.critic_optim.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar( + "train/critic_loss", self.accumulative_meter.get("critic_loss"), self.num_train_step + ) + if self.ptx_coef != 0: + self.writer.add_scalar( + "train/ptx_loss", self.accumulative_meter.get("ptx_loss"), self.num_train_step + ) + self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), self.num_train_step) + self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), self.num_train_step) + self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step) + self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step) + self.accumulative_meter.reset() + + def _learn(self, update_step: int): + """ + Perform the learning step of the PPO algorithm. + + Args: + update_step (int): The current update step. + + Returns: + None + """ + if self.offload_inference_models: + self.experience_maker.initial_model.to("cpu") + self.experience_maker.reward_model.to("cpu") + + # buffer may be empty at first, we should rebuild at each training + if self.sample_buffer: + experience = self.data_buffer.sample() + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) + else: + if isinstance(self.dataloader.sampler, DistributedSampler): + self.dataloader.sampler.set_epoch(update_step) + pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) + for experience in pbar: + self._on_learn_batch_start() + experience.to_device(self.device) + self._training_step(experience) + self._on_learn_batch_end(experience) + + def _save_checkpoint(self, episode: int = 0): + """ + Save the actor and critic checkpoints with running states. + + Args: + episode (int): The current episode number. 
+ + Returns: + None + """ + + self.coordinator.print_on_master("\nStart saving actor checkpoint with running states") + save_checkpoint( + save_dir=self.actor_save_dir, + booster=self.actor_booster, + model=self.actor, + optimizer=self.actor_optim, + lr_scheduler=self.actor_scheduler, + epoch=0, + step=episode + 1, + batch_size=self.train_batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved actor checkpoint at episode {(episode + 1)} at folder {self.actor_save_dir}" + ) + + self.coordinator.print_on_master("\nStart saving critic checkpoint with running states") + save_checkpoint( + save_dir=self.critic_save_dir, + booster=self.critic_booster, + model=self.critic, + optimizer=self.critic_optim, + lr_scheduler=self.critic_scheduler, + epoch=0, + step=episode + 1, + batch_size=self.train_batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved critic checkpoint at episode {(episode + 1)} at folder {self.critic_save_dir}" + ) diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py new file mode 100755 index 000000000..0fb714a62 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -0,0 +1,242 @@ +""" +Reward model trainer +""" + +import os +from typing import Any, Callable, Optional + +import torch +import tqdm +from coati.models import LogSigLoss +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizerBase + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator +from colossalai.utils import get_current_device + +from .base import SLTrainer +from .utils import is_rank_0, to_device + + +class RewardModelTrainer(SLTrainer): + """ + Trainer for reward model.
+ + Args: + model (torch.nn.Module): the reward model to train + booster (Booster): the booster to use for training + optimizer (Optimizer): the optimizer to use for the reward model + lr_scheduler (_LRScheduler): the lr scheduler to use for the reward model + tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding + loss_fn (Callable, defaults to LogSigLoss): the loss function for reward modeling + max_epochs (int, defaults to 1): the max number of epochs to train + beta (float, defaults to 0.1): the beta parameter in the log-sigmoid loss + accumulation_steps (int): the number of steps to accumulate gradients + start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint + save_interval (int): the interval to save model checkpoints, defaults to 0, which means no checkpoint will be saved during training + save_dir (str): the directory to save checkpoints + coordinator (DistCoordinator): the coordinator to use for distributed logging + """ + + def __init__( + self, + model: Any, + booster: Booster, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + tokenizer: PreTrainedTokenizerBase, + loss_fn: Optional[Callable] = None, + max_epochs: int = 1, + beta: float = 0.1, + accumulation_steps: int = 1, + start_epoch: int = 0, + save_interval: int = 0, + save_dir: str = None, + coordinator: DistCoordinator = None, + ) -> None: + super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) + self.actor_scheduler = lr_scheduler + self.tokenizer = tokenizer + self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) + self.save_interval = save_interval + self.coordinator = coordinator + self.save_dir = save_dir + self.num_train_step = 0 + self.accumulation_steps = accumulation_steps + self.device = get_current_device() + self.accumulative_meter = AccumulativeMeanMeter() + + def _before_fit( + self, + train_preference_dataloader: DataLoader = None, + eval_preference_dataloader: DataLoader = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + train_preference_dataloader (DataLoader): the dataloader to use for training preference data + eval_preference_dataloader (DataLoader): the dataloader to use for evaluation preference data + """ + self.train_dataloader = train_preference_dataloader + self.eval_dataloader = eval_preference_dataloader + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + self.wandb_run = wandb.init(project="Coati-rm", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "rm") + log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _train(self, epoch): + self.model.train() + step_bar = tqdm.trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, self.device) + + ( + chosen_input_ids, + chosen_attention_mask, + reject_input_ids, + reject_attention_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + ) + batch_size = chosen_input_ids.size()[0] + + # Concatenate for better parallelism + reward = self.model( + torch.cat([chosen_input_ids, reject_input_ids], dim=0), +
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask], dim=0), + ) + chosen_reward = reward[:batch_size] + reject_reward = reward[batch_size:] + loss = self.loss_fn(chosen_reward, reject_reward).mean() + + self.booster.backward(loss=loss, optimizer=self.optimizer) + + accuracy = (chosen_reward > reject_reward).float() + + # Sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward) + rejected_rewards_mean = all_reduce_mean(tensor=reject_reward) + accuracy_mean = all_reduce_mean(tensor=accuracy) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item()) + + if (i + 1) % self.accumulation_steps == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.actor_scheduler.step() + step_bar.update() + self.num_train_step += 1 + + # Logging + if self.writer and is_rank_0(): + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step) + self.writer.add_scalar( + "train/dist", + self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"), + self.num_train_step, + ) + self.writer.add_scalar( + "train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step + ) + self.writer.add_scalar( + "train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step + ) + self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step) + + self.accumulative_meter.reset() + + # Save checkpoint + if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0: + self.coordinator.print_on_master("\nStart saving model checkpoint with running states") + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.actor_scheduler, + epoch=epoch, + step=i + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}" + ) + step_bar.close() + + def _eval(self, epoch): + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.model.eval() + step_bar = tqdm.trange( + len(self.eval_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0() + ) + with torch.no_grad(): + for i, batch in enumerate(self.eval_dataloader): + batch = to_device(batch, self.device) + ( + chosen_input_ids, + chosen_attention_mask, + reject_input_ids, + reject_attention_mask, + ) = ( + batch["chosen_input_ids"], + batch["chosen_attention_mask"], + batch["reject_input_ids"], + batch["reject_attention_mask"], + ) + + chosen_reward = self.model(chosen_input_ids, attention_mask=chosen_attention_mask) + reject_reward = self.model(reject_input_ids, attention_mask=reject_attention_mask) + loss = self.loss_fn(chosen_reward, reject_reward).mean() + + # Sync + loss_mean = all_reduce_mean(tensor=loss) + chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward) + rejected_rewards_mean = 
all_reduce_mean(tensor=reject_reward) + self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + + step_bar.update() + + msg = "Evaluation Result:\n" + for tag in ["loss", "chosen_rewards", "rejected_rewards"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + msg = ( + msg + + f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n" + ) + self.coordinator.print_on_master(msg) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/sft.py b/applications/ColossalChat/coati/trainer/sft.py new file mode 100755 index 000000000..c95f5b65a --- /dev/null +++ b/applications/ColossalChat/coati/trainer/sft.py @@ -0,0 +1,170 @@ +""" +SFT trainer +""" + +import os +from typing import Optional + +import torch +from coati.trainer.utils import all_reduce_mean +from coati.utils import AccumulativeMeanMeter, save_checkpoint +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import trange + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + +from .base import SLTrainer +from .utils import is_rank_0, to_device + + +class SFTTrainer(SLTrainer): + """ + Trainer for supervised fine-tuning. + + Args: + model (torch.nn.Module): the model to train + booster (Booster): the booster to use for training + optim (Optimizer): the optimizer to use for training + lr_scheduler (_LRScheduler): the lr scheduler to use for training + max_epochs (int, defaults to 2): the number of epochs to train + accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients + """ + + def __init__( + self, + model, + booster: Booster, + optim: Optimizer, + lr_scheduler: _LRScheduler, + max_epochs: int = 2, + accumulation_steps: int = 8, + start_epoch=0, + save_interval: int = None, + save_dir: str = None, + coordinator: Optional[DistCoordinator] = None, + ) -> None: + super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) + + self.accumulation_steps = accumulation_steps + self.scheduler = lr_scheduler + self.save_interval = save_interval + self.save_dir = save_dir + self.coordinator = coordinator + self.num_train_step = 0 + self.num_eval_step = 0 + self.accumulative_meter = AccumulativeMeanMeter() + + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader] = None, + log_dir: Optional[str] = None, + use_wandb: bool = False, + ): + """ + Args: + train_dataloader: the dataloader to use for training + eval_dataloader: the dataloader to use for evaluation + log_dir: the directory to save logs + use_wandb: whether to use wandb for logging + """ + self.train_dataloader = train_dataloader + self.eval_dataloader = eval_dataloader + + self.writer = None + if use_wandb and is_rank_0(): + assert log_dir is not None, "log_dir must be provided when use_wandb is True" + import wandb + + wandb.init(project="Coati-sft", sync_tensorboard=True) + if log_dir is not None and is_rank_0(): + import os + import time + + from torch.utils.tensorboard import SummaryWriter + + log_dir = os.path.join(log_dir, "sft") + log_dir = os.path.join(log_dir,
time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + self.writer = SummaryWriter(log_dir=log_dir) + + def _train(self, epoch: int): + self.model.train() + step_bar = trange( + len(self.train_dataloader) // self.accumulation_steps, + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for i, batch in enumerate(self.train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + batch_size = batch["input_ids"].size(0) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + loss = outputs.loss + self.booster.backward(loss=loss, optimizer=self.optimizer) + + loss_mean = all_reduce_mean(tensor=loss) + self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) + + # Gradient accumulation + if (i + 1) % self.accumulation_steps == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + if self.writer: + self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) + self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) + self.num_train_step += 1 + self.accumulative_meter.reset() + step_bar.update() + + # Save checkpoint + if ( + self.save_dir is not None + and self.save_interval is not None + and (self.num_train_step + 1) % self.save_interval == 0 + ): + save_checkpoint( + save_dir=self.save_dir, + booster=self.booster, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.scheduler, + epoch=epoch, + step=self.num_train_step + 1, + batch_size=batch_size, + coordinator=self.coordinator, + ) + self.coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" + ) + step_bar.close() + + def _eval(self, epoch: int): + if self.eval_dataloader is None: + self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") + return + self.accumulative_meter.reset() + self.model.eval() + with torch.no_grad(): + step_bar = trange( + len(self.eval_dataloader), + desc=f"Epoch {epoch + 1}/{self.max_epochs}", + disable=not is_rank_0(), + ) + for batch in self.eval_dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) + loss_mean = all_reduce_mean(tensor=outputs.loss) + self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) + step_bar.update() + loss_mean = self.accumulative_meter.get("loss") + msg = "Evaluation Result:\n" + for tag in ["loss"]: + msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" + self.coordinator.print_on_master(msg) + with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: + f.write(msg) + step_bar.close() diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py new file mode 100755 index 000000000..5ce1e9ef0 --- /dev/null +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -0,0 +1,113 @@ +""" +Training utilities for Coati. +""" +from typing import Any + +import torch +import torch.distributed as dist +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader + + +class CycledDataLoader: + """ + A data loader that cycles through the data when it reaches the end. + + Args: + dataloader (DataLoader): The original data loader. + + Attributes: + dataloader (DataLoader): The original data loader. 
+ count (int): The number of times the data loader has been cycled. + dataloader_iter (iterable): The iterator for the data loader. + + Methods: + next(): Returns the next batch of data from the data loader, cycling through the data if necessary. + """ + + def __init__( + self, + dataloader: DataLoader, + ) -> None: + self.dataloader = dataloader + + self.count = 0 + self.dataloader_iter = None + + def next(self): + """ + Returns the next batch of data from the data loader, cycling through the data if necessary. + + Returns: + Any: The next batch of data from the data loader. + """ + # defer initialization + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + + self.count += 1 + try: + return next(self.dataloader_iter) + except StopIteration: + self.count = 0 + self.dataloader_iter = iter(self.dataloader) + return next(self.dataloader_iter) + + +def is_rank_0() -> bool: + """ + Check if the current process is the rank 0 process in a distributed training setup. + + Returns: + bool: True if the current process is the rank 0 process, False otherwise. + """ + return not dist.is_initialized() or dist.get_rank() == 0 + + +def to_device(x: Any, device: torch.device) -> Any: + """ + Move the input tensor or nested structure of tensors to the specified device. + + Args: + x (Any): The input tensor or nested structure of tensors. + device (torch.device): The target device to move the tensors to. + + Returns: + Any: The tensor or nested structure of tensors moved to the target device. + """ + + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + """ + Perform all-reduce operation on the given tensor and compute the mean across all processes. + + Args: + tensor (torch.Tensor): The input tensor to be reduced. + + Returns: + torch.Tensor: The reduced tensor with mean computed across all processes. + """ + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Performs an all-reduce operation to sum the values of the given tensor across all processes. + + Args: + tensor (torch.Tensor): The input tensor to be reduced. + + Returns: + torch.Tensor: The reduced tensor with the sum of values across all processes. + """ + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + return tensor diff --git a/applications/ColossalChat/coati/utils/__init__.py b/applications/ColossalChat/coati/utils/__init__.py new file mode 100755 index 000000000..5d651b202 --- /dev/null +++ b/applications/ColossalChat/coati/utils/__init__.py @@ -0,0 +1,4 @@ +from .accumulative_meter import AccumulativeMeanMeter +from .ckpt_io import load_checkpoint, save_checkpoint + +__all__ = ["load_checkpoint", "save_checkpoint", "AccumulativeMeanMeter"] diff --git a/applications/ColossalChat/coati/utils/accumulative_meter.py b/applications/ColossalChat/coati/utils/accumulative_meter.py new file mode 100755 index 000000000..9ed662d2d --- /dev/null +++ b/applications/ColossalChat/coati/utils/accumulative_meter.py @@ -0,0 +1,69 @@ +""" +A class that can be used to calculate the mean of a variable +""" + + +class AccumulativeMeanVariable: + """ + A class that calculates the accumulative mean of a variable. + """ + + def __init__(self): + self._sum = 0 + self._count = 0 + + def add(self, value, count_update=1): + """ + Adds a value to the sum and updates the count. 
+ + Args: + value (float): The value to be added. + count_update (int, optional): The amount to update the count by. Defaults to 1. + """ + self._sum += value + self._count += count_update + + def get(self): + """ + Calculates and returns the accumulative mean. + + Returns: + float: The accumulative mean. + """ + return self._sum / self._count if self._count > 0 else 0 + + def reset(self): + """ + Resets the sum and count to zero. + """ + self._sum = 0 + self._count = 0 + + +class AccumulativeMeanMeter: + """ + A class for calculating and storing the accumulative mean of variables. + + Attributes: + variable_dict (dict): A dictionary to store the accumulative mean variables. + + Methods: + add(name, value, count_update=1): Adds a value to the specified variable. + get(name): Retrieves the accumulative mean value of the specified variable. + reset(): Resets all the accumulative mean variables to their initial state. + """ + + def __init__(self): + self.variable_dict = {} + + def add(self, name, value, count_update=1): + if name not in self.variable_dict: + self.variable_dict[name] = AccumulativeMeanVariable() + self.variable_dict[name].add(value, count_update=count_update) + + def get(self, name): + return self.variable_dict[name].get() + + def reset(self): + for name in self.variable_dict: + self.variable_dict[name].reset() diff --git a/applications/ColossalChat/coati/utils/ckpt_io.py b/applications/ColossalChat/coati/utils/ckpt_io.py new file mode 100755 index 000000000..5b804f0ac --- /dev/null +++ b/applications/ColossalChat/coati/utils/ckpt_io.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Helper functions for saving and loading checkpoints +""" + +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + + """ + Temporary disable the following as save_optimizer causes all processes to hang in a multi-gpu environment, + working on fixing this bug + """ + + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. + booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling")) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/ColossalChat/config/conversation_template/Qwen.json b/applications/ColossalChat/config/conversation_template/Qwen.json new file mode 100644 index 000000000..09f706ffe --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/Qwen.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + null + ] +} diff --git a/applications/ColossalChat/config/conversation_template/Vicuna.json b/applications/ColossalChat/config/conversation_template/Vicuna.json new file mode 100644 index 000000000..2b00b6529 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/Vicuna.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\\'t know the answer to a question, please don\\'t share false information.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/Yi.json b/applications/ColossalChat/config/conversation_template/Yi.json new file mode 100644 index 000000000..9716413b5 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/Yi.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/chatGLM2.json b/applications/ColossalChat/config/conversation_template/chatGLM2.json new file mode 100644 index 000000000..a2638dbe7 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/chatGLM2.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/colossal-llama2.json b/applications/ColossalChat/config/conversation_template/colossal-llama2.json new file mode 100644 index 000000000..cc7f1e5d7 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/colossal-llama2.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/llama2.json b/applications/ColossalChat/config/conversation_template/llama2.json new file mode 100644 index 000000000..80558f976 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/llama2.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/mistral.json b/applications/ColossalChat/config/conversation_template/mistral.json new file mode 100644 index 000000000..b48c3a3f2 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/mistral.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "system_message": null, + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/config/conversation_template/zephyr.json b/applications/ColossalChat/config/conversation_template/zephyr.json new file mode 100644 index 000000000..2ab141111 --- /dev/null +++ b/applications/ColossalChat/config/conversation_template/zephyr.json @@ -0,0 +1,7 @@ +{ + "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 2 + ] +} diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md new file mode 100755 index 000000000..cfed3f1f3 --- /dev/null +++ b/applications/ColossalChat/examples/README.md @@ -0,0 +1,565 @@ +# Examples + +## Table of Contents + +- [Examples](#examples) + - [Table of Contents](#table-of-contents) + - [Install Requirements](#install-requirements) + - [Get Start with ColossalRun](#get-start-with-colossalrun) + - [Training Configuration](#training-configuration) + - [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning) + - [Step 1: Data Collection](#step-1-data-collection) + - [Step 2: Preprocessing](#step-2-preprocessing) + - [Step 3: Training](#step-3-training) + - [RLHF Stage 2: Training Reward Model](#rlhf-training-stage2---training-reward-model) + - [Step 1: Data Collection](#step-1-data-collection-1) + - [Step 2: Preprocessing](#step-2-preprocessing-1) + - [Step 3: Training](#step-3-training-1) + - [Features and Tricks in RM Training](#features-and-tricks-in-rm-training) + - [RLHF Stage 3: Proximal Policy Optimization](#rlhf-training-stage3---proximal-policy-optimization) + - [Step 1: Data Collection](#step-1-data-collection-2) + - [Step 2: Preprocessing](#step-2-preprocessing-2) + - [Step 3: Training](#step-3-training-3) + - [PPO Training Results](#sample-training-results-using-default-script) + - [Reward](#reward) + - [KL Divergence](#approximate-kl-divergence) + - [Note on PPO Training](#note-on-ppo-training) + - [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization) + - [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning) + - [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training) + - [Hardware Requirements](#hardware-requirements) + - [Inference example](#inference-example) + - [Attention](#attention) + +--- + +## Install requirements + +```shell +pip install -r requirements.txt +``` + + +## Get Start with ColossalRun + +You can use colossalai run to launch multi-nodes training: +``` +colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \ +train.py --OTHER_CONFIGURATIONS +``` +Here is a sample hostfile: + +``` +hostname1 +hostname2 +hostname3 +hostname4 +``` + +Make sure master node can access all nodes (including itself) by ssh without password. Here are some other arguments. + +- nnodes: number of nodes used in the training +- nproc-per-node: specifies the number of processes to be launched per node +- rdzv-endpoint: address of the host node + +### Training Configuration + +This section gives a simple introduction on different training strategies that you can use and how to use them with our boosters and plugins to reduce training time and VRAM consumption. For more detail regarding training strategies, please refer to [here](https://colossalai.org/docs/concepts/paradigms_of_parallelism). For details regarding boosters and plugins, please refer to [here](https://colossalai.org/docs/basics/booster_plugins). + + +
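+All of the strategies below are consumed through Colossal-AI's `Booster` API inside the training scripts. As a rough orientation only, here is a minimal sketch of that pattern (the model and optimizer are placeholders, and names/defaults may differ from the actual scripts):
+
+```python
+import torch
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin  # or LowLevelZeroPlugin, HybridParallelPlugin, ...
+
+# In a real script, colossalai.launch_from_torch() is called first to set up the distributed state.
+plugin = GeminiPlugin()                # pick the strategy described in the sections below
+booster = Booster(plugin=plugin)
+
+model = torch.nn.Linear(8, 8)          # placeholder for the actual causal LM
+optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+
+# The booster wraps the objects so that sharding/offloading is handled transparently afterwards.
+model, optimizer, *_ = booster.boost(model, optimizer)
+```
+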
                  Gemini + +This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk). + +Below shows how to use the gemini in SFT training. +``` +colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --save_interval 5000 \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --plugin gemini \ + --batch_size 4 \ + --max_epochs 1 \ + --accumulation_steps 1 \ # the gradient accumulation has to be disabled + --lr 2e-5 \ + --max_len 2048 \ + --use_wandb +``` + +
                  + +
+**Gemini-Auto**
+
+This option uses Gemini and automatically offloads low-priority tensors to the CPU. Like Gemini, it does not support local gradient accumulation. More details can be found in [Gemini Doc](https://colossalai.org/docs/features/zero_with_chunk).
+
+Below shows how to use gemini_auto in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+    --pretrain $PRETRAINED_MODEL_PATH \
+    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+    --dataset ${dataset[@]} \
+    --save_interval 5000 \
+    --save_path $SAVE_DIR \
+    --config_file $CONFIG_FILE \
+    --plugin gemini_auto \
+    --batch_size 4 \
+    --max_epochs 1 \
+    --accumulation_steps 1 \  # the gradient accumulation has to be disabled
+    --lr 2e-5 \
+    --max_len 2048 \
+    --use_wandb
+```
+
                  + + + +
+**Zero2**
+
+This option shards the optimizer states and gradients across multiple GPUs and does not offload weights to the CPU. It uses reduce and gather operations to synchronize gradients and weights. It does not support local gradient accumulation: you can still accumulate gradients if you insist, but it brings no saving in communication cost, which is also why combining Zero-2 with pipeline parallelism is not a good idea.
+
+Below shows how to use zero2 in SFT training.
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+    --pretrain $PRETRAINED_MODEL_PATH \
+    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+    --dataset ${dataset[@]} \
+    --save_interval 5000 \
+    --save_path $SAVE_DIR \
+    --config_file $CONFIG_FILE \
+    --plugin zero2 \
+    --batch_size 4 \
+    --max_epochs 1 \
+    --accumulation_steps 4 \
+    --lr 2e-5 \
+    --max_len 2048 \
+    --use_wandb
+```
+
                  + + +
                  Zero2CPU + +This option will distribute the optimizer parameters and the gradient to multiple GPUs as well as offload parameters to cpu. It does not support local gradient accumulation. Though you can accumulate gradient if you insist, it cannot reduce communication cost. + +Below shows how to use the zero2-cpu in SFT training. +``` +colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --save_interval 5000 \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --plugin zero2_cpu \ + --batch_size 4 \ + --max_epochs 1 \ + --accumulation_steps 4 \ + --lr 2e-5 \ + --max_len 2048 \ + --use_wandb +``` + +
                  + +
+**Tensor Parallelism**
+
+This option supports Tensor Parallelism (TP). Note that if you enable TP, ZeRO and pipeline parallelism will be disabled. TP splits large model weights, optimizer states and gradients into smaller shards and distributes them across multiple GPUs, hence it is recommended when your model is large (e.g. 20B and above) or your training algorithm consumes a lot of memory (e.g. PPO).
+
+Below shows how to use TP in PPO training.
+```
+colossalai run --nproc_per_node 4 --hostfile hostfile --master_port 30039 train_ppo.py \
+    --pretrain $PRETRAINED_MODEL_PATH \
+    --rm_pretrain $PRETRAINED_MODEL_PATH \
+    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+    --rm_checkpoint_path $REWARD_MODEL_PATH \
+    --prompt_dataset ${prompt_dataset[@]} \
+    --pretrain_dataset ${ptx_dataset[@]} \
+    --ptx_batch_size 1 \
+    --ptx_coef 0.0 \
+    --plugin "zero2" \
+    --save_interval 200 \
+    --save_path $SAVE_DIR \
+    --num_episodes 2000 \
+    --num_collect_steps 4 \
+    --num_update_steps 1 \
+    --experience_batch_size 8 \
+    --train_batch_size 4 \
+    --accumulation_steps 8 \
+    --tp 4 \  # TP size, nproc_per_node must be divisible by it
+    --lr 9e-6 \
+    --mixed_precision "bf16" \
+    --grad_clip 1.0 \
+    --weight_decay 0.01 \
+    --warmup_steps 100 \
+    --grad_checkpoint \
+    --use_wandb
+```
+
                  + + +
                  Gradient Checkpointing + +This option saves VRAM consumption by selectively recomputing some of the intermediate value on-the-fly during the backward pass, rather than storing them in memory. + +To enable gradient checkpointing, add --grad_checkpoint to your training script. +``` +colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --save_interval 5000 \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --plugin zero2_cpu \ + --batch_size 4 \ + --max_epochs 1 \ + --accumulation_steps 4 \ + --lr 2e-5 \ + --max_len 2048 \ + --grad_checkpoint \ # This enables gradient checkpointing + --use_wandb +``` + +
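+For intuition only, the snippet below shows the underlying PyTorch mechanism; `--grad_checkpoint` enables the equivalent behaviour inside the supported models, so you do not need to write this yourself:
+
+```python
+import torch
+from torch.utils.checkpoint import checkpoint
+
+block = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU())
+x = torch.randn(4, 1024, requires_grad=True)
+
+# Activations inside `block` are not kept; they are recomputed during the backward pass.
+y = checkpoint(block, x, use_reentrant=False)
+y.sum().backward()
+```
+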
                  + +
                  Flash Attention + +Details about flash attention can be found in the paper: [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135). + +To enable flash attention, add --use_flash_attn to your training script. +``` +colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --save_interval 5000 \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --plugin zero2_cpu \ + --batch_size 4 \ + --max_epochs 1 \ + --accumulation_steps 4 \ + --lr 2e-5 \ + --max_len 2048 \ + --use_flash_attn \ # This enables flash attention + --use_wandb +``` + +
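+Conceptually, flash attention computes the same result as standard scaled-dot-product attention, just with an IO-aware fused kernel. For illustration only (this is not the code path enabled by `--use_flash_attn`), PyTorch 2.x exposes the same idea and can dispatch to a flash-attention kernel on supported GPUs:
+
+```python
+import torch
+import torch.nn.functional as F
+
+q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq_len, head_dim)
+k, v = torch.randn_like(q), torch.randn_like(q)
+
+# A memory-efficient or flash kernel is selected automatically when available.
+out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+```
+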
                  + +
+**Low Rank Adaptation**
+
+Details about Low-Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces VRAM consumption at the cost of some model capability, which makes it suitable for training LLMs with constrained resources.
+
+To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
+```
+colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
+    --pretrain $PRETRAINED_MODEL_PATH \
+    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+    --dataset ${dataset[@]} \
+    --save_interval 5000 \
+    --save_path $SAVE_DIR \
+    --config_file $CONFIG_FILE \
+    --plugin zero2_cpu \
+    --batch_size 4 \
+    --max_epochs 1 \
+    --accumulation_steps 4 \
+    --lr 2e-5 \
+    --max_len 2048 \
+    --lora_rank 32 \  # This enables LoRA
+    --use_wandb
+```
+
                  + +
+**Other Training Arguments**
+
+- grad_clip: gradients larger than this value will be clipped.
+- weight_decay: weight decay hyper-parameter.
+- warmup_steps: number of warmup steps used in setting up the learning rate scheduler.
+- pretrain: pretrained model path; weights will be loaded from this model unless checkpoint_path is provided.
+- tokenizer_dir: where to load the tokenizer from; if not provided, the tokenizer will be loaded from the pretrained model path.
+- dataset: a list of strings, each a path to a folder containing buffered dataset files in arrow format.
+- checkpoint_path: if provided, weights will be loaded from this checkpoint.
+- config_file: path to store the training config file.
+- save_dir: path to store the model checkpoints.
+- max_length: inputs will be padded/truncated to max_length before being fed to the model.
+- max_epochs: number of epochs to train.
+- batch_size: training batch size.
+- mixed_precision: precision to use in training; supports 'fp16' and 'bf16'. Note that some devices may not support 'bf16'; please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
+- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
+- merge_lora_weights: whether to merge LoRA weights before saving the model.
+- lr: the learning rate used in training.
+- accumulation_steps: accumulate gradients every accumulation_steps.
+- log_dir: path to store the logs.
+- use_wandb: if this flag is set, you can view logs on wandb.
+
                  + +### RLHF Training Stage1 - Supervised Instructs Tuning + +Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat: + +#### Step 1: Data Collection +The first step in Stage 1 is to collect a dataset of human demonstrations of the following format. + +```json +[ + {"messages": + [ + { + "from": "human", + "content": "what are some pranks with a pen i can do?" + }, + { + "from": "assistant", + "content": "Are you looking for practical joke ideas?" + }, + ... + ] + }, + ... +] +``` + +#### Step 2: Preprocessing +Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization. + +In this code we provide a flexible way for users to set the conversation template for formatting chat data using Huggingface's newest feature--- chat template. Please follow the following steps to define your chat template and preprocess your data. + +- Step 1: (Optional). Define your conversation template. You need to provide a conversation template config file similar to the config files under the ./config/conversation_template directory. This config should include the following fields. + ```json + { + "chat_template": (Optional), A string of chat_template used for formatting chat data. If not set (None), will use the default chat template of the provided tokenizer. If a path to a huggingface model or local model is provided, will use the chat_template of that model. To use a custom chat template, you need to manually set this field. For more details on how to write a chat template in Jinja format, please read https://huggingface.co/docs/transformers/main/chat_templating, + "system_message": A string of system message to be added at the beginning of the prompt. If no is provided (None), no system message will be added, + "stop_ids": (Optional), A list of string indicating the end of assistant's response during the rollout stage of PPO training. It's recommended to set this manually for PPO training. If not set, will set to tokenizer.eos_token_ids automatically, + } + ``` + On your first run of the data preparation script, you only need to define the "chat_template" (if you want to use custom chat template) and the "system message" (if you want to use a custom system message), + +- Step 2: Run the data preparation script--- [prepare_sft_dataset.sh](./examples/data_preparation_scripts/prepare_sft_dataset.sh). Note that whether or not you have skipped the first step, you need to provide the path to the conversation template config file (via the conversation_template_config arg). If you skipped the first step, an auto-generated conversation template will be stored at the designated file path. + +- Step 3: (Optional) Check the correctness of the processed data. We provided an easy way for you to do a manual checking on the processed data by checking the "$SAVE_DIR/jsonl/part-XXXX.jsonl" files. + +Finishing the above steps, you have converted the raw conversation to the designated chat format and tokenized the formatted conversation, calculate input_ids, labels, attention_masks and buffer those into binary dataset files under "$SAVE_DIR/arrow/part-XXXX" folders. 
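+
+If you want to eyeball the formatting before running the full preparation pipeline, you can render a couple of messages through the tokenizer's chat template yourself. A minimal sketch (the tokenizer path is a placeholder; note that chat templates expect `role`/`content` keys, while the raw data files use `from`/`content`):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("PATH/TO/TOKENIZER", trust_remote_code=True)
+messages = [
+    {"role": "user", "content": "what are some pranks with a pen i can do?"},
+    {"role": "assistant", "content": "Are you looking for practical joke ideas?"},
+]
+# Returns the formatted prompt string produced by the (default or custom) chat template.
+print(tokenizer.apply_chat_template(messages, tokenize=False))
+```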
+
+For example, our Colossal-LLaMA-2 format looks like this:
+```
+ A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+
+Human: what are some pranks with a pen i can do? Assistant: Are you looking for practical joke ideas?
+...
+```
+
+#### Step 3: Training
+Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start supervised instruction fine-tuning. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+### RLHF Training Stage2 - Training Reward Model
+
+Stage2 trains a reward model on preference data: human annotators rank different responses to the same prompt, and these rankings supervise the reward model so that it learns to assign higher scores to preferred responses.
+
+#### Step 1: Data Collection
+Below shows the preference dataset format used in training the reward model.
+
+```json
+[
+    {"context": [
+        {
+          "from": "human",
+          "content": "Introduce butterflies species in Oregon."
+        }
+      ],
+      "chosen": [
+        {
+          "from": "assistant",
+          "content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
+        },
+        ...
+      ],
+      "rejected": [
+        {
+          "from": "assistant",
+          "content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
+        },
+        ...
+      ]
+    },
+    ...
+]
+```
+
+#### Step 2: Preprocessing
+Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
+
+#### Step 3: Training
+You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+#### Features and Tricks in RM Training
+
+- We recommend using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets for training the reward model.
+- We support two kinds of loss functions: `log_sig` (used by OpenAI) and `log_exp` (used by Anthropic).
+- We log the training accuracy `train/acc`, `reward_chosen` and `reward_rejected` to monitor progress during training.
+- We use a cosine-decay learning-rate scheduler for RM training.
+- We implement value_head as a single linear layer and initialize its weights from the N(0, 1/(d_model + 1)) distribution.
+
+#### Note on Reward Model Training
+
+Before you move on to the next stage, please check the following list to ensure that your reward model is stable and robust. You can check the reward chart and the accuracy chart on wandb.
+- The mean reward for chosen data is much higher than that for rejected data.
+- The accuracy is larger than 0.5 by a significant margin (usually it should be greater than 0.6).
+- Optional: check that the reward is positive for chosen data and negative for rejected data.
+
+Your training reward curves should look similar to the following charts.
+

+*Figure: reward and accuracy curves logged during reward model training.*

                  + +### RLHF Training Stage3 - Proximal Policy Optimization + +In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimization (PPO), which is the most complex part of the training process: + +

+*Figure: overview of the PPO training stage.*

+
+#### Step 1: Data Collection
+PPO uses two kinds of training data: the prompt data and the pretrain data (optional). The first dataset is mandatory; data samples within the prompt dataset end with a line from the "human", and the "assistant" needs to generate a response to that line. Note that you can still use conversations that end with a line from the "assistant"; in that case, the last line will be dropped. Here is an example of the prompt dataset format.
+
+```json
+[
+    {"messages":
+      [
+        {
+          "from": "human",
+          "content": "what are some pranks with a pen i can do?"
+        }
+        ...
+      ]
+    },
+]
+```
+
+The second dataset, the pretrain dataset, is optional; provide it if you want to use the ptx loss introduced in the [InstructGPT paper](https://arxiv.org/abs/2203.02155). It uses the following format.
+
+```json
+  [
+      {
+          "source": "", # system instruction
+          "Target": "Provide a list of the top 10 most popular mobile games in Asia\nThe top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+      },
+      ...
+  ]
+  ```
+#### Step 2: Preprocessing
+To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh).
+
+For the ptx dataset, you can reuse the SFT dataset you prepared in the SFT stage or prepare a new one from a different source. The ptx data is used to calculate the ptx loss, which stabilizes the training according to the [InstructGPT paper](https://arxiv.org/pdf/2203.02155.pdf).
+
+#### Step 3: Training
+You can run [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Below are the arguments unique to PPO; please refer to the [training configuration](#training-configuration) section for details regarding the other supported training options.
+
+```bash
+--pretrain $PRETRAINED_MODEL_PATH \
+--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architecture
+--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
+--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
+--prompt_dataset ${prompt_dataset[@]} \ # list of strings, prompt dataset
+--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \ # path to the conversation template config file
+--pretrain_dataset ${ptx_dataset[@]} \ # list of strings, the sft dataset
+--ptx_batch_size 1 \ # batch size for calculating the ptx loss
+--ptx_coef 0.0 \ # non-zero if the ptx loss is enabled
+--num_episodes 2000 \ # number of episodes to train
+--num_collect_steps 1 \
+--num_update_steps 1 \
+--experience_batch_size 8 \
+--train_batch_size 4 \
+--accumulation_steps 2
+```
+
+Each episode has two phases, the collect phase and the update phase. During the collect phase, we collect experiences (answers generated by the actor) and store them in the ExperienceBuffer. During the update phase, the data in the ExperienceBuffer is used to update the parameters of the actor and the critic.
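+
+As a quick sanity check of this batch-size bookkeeping (formalized in the two identities right below), you can verify a launch configuration offline; the numbers here are illustrative only:
+
+```python
+# Illustrative values -- substitute your own launch arguments.
+num_process = 4            # total number of training processes
+tp = 1                     # tensor parallel size (1 means TP disabled)
+num_collect_steps = 4
+experience_batch_size = 8
+train_batch_size = 4
+accumulation_steps = 8
+
+num_tp_group = num_process // tp
+buffer_size = num_tp_group * num_collect_steps * experience_batch_size
+consumed_per_episode = train_batch_size * accumulation_steps * num_tp_group
+assert buffer_size == consumed_per_episode, (buffer_size, consumed_per_episode)
+```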
+ +- Without tensor parallelism, +``` +experience buffer size += num_process * num_collect_steps * experience_batch_size += train_batch_size * accumulation_steps * num_process +``` + +- With tensor parallelism, +``` +num_tp_group = num_process / tp +experience buffer size += num_tp_group * num_collect_steps * experience_batch_size += train_batch_size * accumulation_steps * num_tp_group +``` + +### Sample Training Results Using Default Script +#### Reward +

+*Figure: reward curve from a sample PPO run with the default script.*

+
+### Note on PPO Training
+#### Q1: My reward is negative
+Answer: Check the reward model you trained in stage 1. If the reward model only produces negative rewards, a negative reward during PPO is expected. However, even if the reward starts out negative, it should still go up over the course of training.
+
+#### Q2: My actor loss is negative
+Answer: This is normal for the actor loss, as PPO doesn't restrict the actor loss to be positive.
+
+#### Q3: My reward doesn't go up (decreases)
+Answer: The causes of this problem are usually two-fold. Check your reward model and make sure that it gives strong positive rewards for good responses and strong negative rewards for bad responses. You should also try different hyperparameter settings.
+
+#### Q4: Generation is garbage
+Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviates from its original state, which may lead to a decrease in its language modeling capabilities. A way to fix this is to add a supervised loss during PPO: set ptx_coef to a non-zero value (between 0 and 1), which balances the PPO loss and the SFT loss.
+
+## Alternative Option For RLHF: Direct Preference Optimization
+
+For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in [this paper](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
+
+### DPO Training Stage1 - Supervised Instructs Tuning
+
+Please refer to the [SFT section](#rlhf-training-stage1---supervised-instructs-tuning) in the PPO part.
+
+### DPO Training Stage2 - DPO Training
+#### Step 1: Data Collection & Preparation
+For DPO training, you only need the preference dataset. Please follow the instructions in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
+
+#### Step 2: Training
+You can run [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. Please refer to the [training configuration](#training-configuration) section for details regarding supported training options.
+
+#### DPO Result

+*Figure: sample DPO training result.*

                  + +## Hardware Requirements +For PPO, we suggest using Tensor Parallelism. The following table shows the VRAM consumption of training a 7B model on a dummy dataset with 2048 sequence length and 512 layout length with different tp_size (equal to the number of GPUs). In this experiment, we use H800 GPU with 80GB VRAM. +| PPO | tp=8 | tp=4 | +|-------|---------------|---------------| +| bs=1 | 18485.19 MB | 42934.45 MB | +| bs=4 | 25585.65 MB | 42941.93 MB | +| bs=16 | 41408.28 MB | 56778.97 MB | +| bs=30 | 64047.42 MB | failed | + +For DPO, we recommend using zero2 or zero2-cpu. We tested the VRAM consumption on a dummy dataset with 2048 sequence length. + +- 1 H800 GPU + - zero2-cpu, batch size=2, VRAM Usage=49873.90 MB + - zero2-cpu, batch size=4, VRAM Usage=60998.22 MB +- 4 H800 GPUs + - zero2, batch size=4, VRAM Usage=67544.47 MB + +## Inference example + +We support different inference options, including int8 and int4 quantization. +For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference). + +## Attention + +The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance. diff --git a/applications/Chat/examples/community/README.md b/applications/ColossalChat/examples/community/README.md old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/README.md rename to applications/ColossalChat/examples/community/README.md diff --git a/applications/Chat/examples/community/peft/README.md b/applications/ColossalChat/examples/community/peft/README.md old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/peft/README.md rename to applications/ColossalChat/examples/community/peft/README.md diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/ColossalChat/examples/community/peft/easy_dataset.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/peft/easy_dataset.py rename to applications/ColossalChat/examples/community/peft/easy_dataset.py diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/ColossalChat/examples/community/peft/easy_models.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/peft/easy_models.py rename to applications/ColossalChat/examples/community/peft/easy_models.py diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/ColossalChat/examples/community/peft/train_peft_prompts.py old mode 100644 new mode 100755 similarity index 98% rename from applications/Chat/examples/community/peft/train_peft_prompts.py rename to applications/ColossalChat/examples/community/peft/train_peft_prompts.py index 1dd9ffcdf..4625f2bda --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/ColossalChat/examples/community/peft/train_peft_prompts.py @@ -24,7 +24,9 @@ def main(args): if args.strategy == "ddp": strategy = DDPStrategy() elif args.strategy == "colossalai_gemini": - strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5) + strategy = GeminiStrategy( + placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5 + ) elif args.strategy == "colossalai_zero2": strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: diff --git 
a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/ColossalChat/examples/community/peft/train_peft_sft.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/peft/train_peft_sft.py rename to applications/ColossalChat/examples/community/peft/train_peft_sft.py diff --git a/applications/Chat/examples/community/ray/README.md b/applications/ColossalChat/examples/community/ray/README.md old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/ray/README.md rename to applications/ColossalChat/examples/community/ray/README.md diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/ColossalChat/examples/community/ray/ray_job_script.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/ray/ray_job_script.py rename to applications/ColossalChat/examples/community/ray/ray_job_script.py diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/community/ray/train_prompts_on_ray.py rename to applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py new file mode 100644 index 000000000..64093f88d --- /dev/null +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Prepare dataset scripts + +Usage: +- For SFT dataset preparation (SFT) +python prepare_dataset.py --type sft \ + --data_input_dirs /PATH/TO/SFT/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ + +- For prompt dataset preparation (PPO) +python prepare_dataset.py --type prompt \ + --data_input_dirs /PATH/TO/SFT/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ + +- For Preference dataset preparation (DPO and Reward model training) +python prepare_dataset.py --type preference \ + --data_input_dirs /PATH/TO/SFT/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ +""" + +import argparse +import json +import math +import os +import random +import time +from multiprocessing import cpu_count + +from coati.dataset import setup_conversation_template, supervised_tokenize_sft, tokenize_prompt_dataset, tokenize_rlhf +from datasets import dataset_dict, load_dataset +from transformers import AutoTokenizer + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--type", + type=str, + required=True, + default=None, + choices=["sft", "prompt", "preference"], + help="Type of dataset, chose from 'sft', 'prompt', 'preference'.", + ) + parser.add_argument( + 
"--data_input_dirs", + type=str, + required=True, + default=None, + help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.", + ) + parser.add_argument( + "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer" + ) + parser.add_argument( + "--conversation_template_config", + type=str, + default="conversation_template_config", + help="Path \ + to save conversation template config files.", + ) + parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory") + parser.add_argument( + "--data_jsonl_output_dir", + type=str, + default="jsonl_output", + help="Output directory of spliced dataset with jsonl format", + ) + parser.add_argument( + "--data_arrow_output_dir", + type=str, + default="arrow_output", + help="Output directory of spliced dataset with arrow format", + ) + parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence") + parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins") + parser.add_argument( + "--num_samples_per_datafile", + type=int, + default=-1, + help="Number of samples to be generated from each data file. -1 denote all samples.", + ) + args = parser.parse_args() + + if args.num_spliced_dataset_bins >= 100000: + raise ValueError("Too many spliced divisions, must be smaller than 100000") + + assert not os.path.exists(args.data_cache_dir), f"Find existed data cache dir {args.data_cache_dir}" + assert not os.path.exists( + args.data_jsonl_output_dir + ), f"Find existed jsonl data output dir {args.data_jsonl_output_dir}" + assert not os.path.exists( + args.data_arrow_output_dir + ), f"Find existed arrow data output dir {args.data_arrow_output_dir}" + os.makedirs(args.data_jsonl_output_dir) + os.makedirs(args.data_arrow_output_dir) + + # Prepare to all input datasets + input_data_paths = [] + input_data_dirs = args.data_input_dirs.split(",") + for ds_dir in input_data_dirs: + ds_dir = os.path.abspath(ds_dir) + assert os.path.exists(ds_dir), f"Not find data dir {ds_dir}" + ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")] + ds_paths = [os.path.join(ds_dir, name) for name in ds_files] + input_data_paths.extend(ds_paths) + + # Prepare to data splitting. + train_splits = [] + split_interval = math.ceil(100 / args.num_spliced_dataset_bins) + for i in range(0, 100, split_interval): + start = i + end = i + split_interval + if end > 100: + end = 100 + train_splits.append(f"train[{start}%:{end}%]") + + # Prepare the tokenizer. + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, use_fast=False, trust_remote_code=True) + if os.path.exists(args.conversation_template_config): + chat_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8")) + else: + chat_template_config = { + "system_message": "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" + } # Use default system message + if args.type == "preference": + if "stop_ids" not in chat_template_config: + # Ask the user to define stop_ids for PPO training + dummy_messages = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. 
How can I help you today?"}, + {"role": "user", "content": "Who made you?"}, + {"role": "assistant", "content": "I am a chatbot trained by Colossal-AI."}, + ] + dummy_prompt = tokenizer.apply_chat_template(dummy_messages, tokenize=False) + tokenized = tokenizer(dummy_prompt, add_special_tokens=False)["input_ids"] + tokens = tokenizer.convert_ids_to_tokens(tokenized, skip_special_tokens=False) + corresponding_str = [tokenizer.convert_tokens_to_string([token]) for token in tokens] + token_id_mapping = [{"token": s, "id": tokenized[i]} for i, s in enumerate(corresponding_str)] + stop_ids = input( + "For PPO, we recommend to provide stop_ids for the properly stop the generation during roll out stage. " + "stop_ids are the ids of repetitive pattern that indicate the end of the assistant's response. " + "Here is an example of formatted prompt and token-id mapping, you can set stop_ids by entering a list " + "of integers, separate by space, press `Enter` to end. Or you can press `Enter` without input if you are " + "not using PPO or you prefer to not set the stop_ids, in that case, stop_ids will be set to tokenizer.eos_token_id. " + f"\nPrompt:\n{dummy_prompt}\nToken-id Mapping:\n{token_id_mapping}\nstop_ids:" + ) + if stop_ids == "": + chat_template_config["stop_ids"] = [tokenizer.eos_token_id] + else: + try: + chat_template_config["stop_ids"] = [int(s) for s in stop_ids.split()] + except ValueError: + raise ValueError("Invalid input, please provide a list of integers.") + else: + # Set stop_ids to eos_token_id for other dataset types if not exist + if "stop_ids" not in chat_template_config: + chat_template_config["stop_ids"] = [tokenizer.eos_token_id] + + conversation_template = setup_conversation_template( + tokenizer, chat_template_config=chat_template_config, save_path=args.conversation_template_config + ) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + + list_dataset = load_dataset( + path="json", + data_files=input_data_paths, + cache_dir=os.path.join(args.data_cache_dir, "raw"), + keep_in_memory=False, + split=train_splits, + num_proc=cpu_count(), + ) + + if args.type == "sft": + preparation_function = supervised_tokenize_sft + elif args.type == "prompt": + preparation_function = tokenize_prompt_dataset + elif args.type == "preference": + preparation_function = tokenize_rlhf + else: + raise ValueError("Unknow dataset type. Please choose one from ['sft', 'prompt', 'preference']") + + for index, dataset in enumerate(list_dataset): + assert isinstance(dataset, dataset_dict.Dataset) + if len(dataset) == 0: + # Hack: Skip empty dataset. 
If dataset contains less than num_of_rank samples, some rank may have empty dataset and leads to error + continue + if args.num_samples_per_datafile > 0: + # limit the number of samples in each dataset + dataset = dataset.select( + random.sample(range(len(dataset)), min(args.num_samples_per_datafile, len(dataset))) + ) + logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.") + dataset = dataset.map( + function=preparation_function, + fn_kwargs={ + "tokenizer": tokenizer, + "conversation_template": conversation_template, + "max_length": args.max_length, + }, + keep_in_memory=False, + num_proc=min(len(dataset), cpu_count()), + ) + + dataset = dataset.filter( + lambda data: data["chosen_input_ids" if args.type == "preference" else "input_ids"] is not None + ) + + # Save each jsonl spliced dataset. + output_index = "0" * (5 - len(str(index))) + str(index) + output_name = f"part-{output_index}" + output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl") + st = time.time() + with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer: + count = 0 + for data_point in dataset: + if count % 500 == 0: + logger.info(f"processing {count} spliced data points for {fp_writer.name}") + count += 1 + fp_writer.write(json.dumps(data_point, ensure_ascii=False) + "\n") + logger.info( + f"Current file {fp_writer.name}; " + f"Data size: {len(dataset)}; " + f"Time cost: {round((time.time() - st) / 60, 6)} minutes." + ) + # Save each arrow spliced dataset + output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name) + logger.info(f"Start to save {output_arrow_path}") + dataset = load_dataset( + path="json", + data_files=[output_jsonl_path], + cache_dir=os.path.join(args.data_cache_dir, "tokenized"), + keep_in_memory=False, + num_proc=cpu_count(), + split="train", + ) + dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(dataset), cpu_count())) + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh new file mode 100755 index 000000000..999d7778b --- /dev/null +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_preference_dataset.sh @@ -0,0 +1,13 @@ +SAVE_DIR="" + +rm -rf $SAVE_DIR/cache +rm -rf $SAVE_DIR/jsonl +rm -rf $SAVE_DIR/arrow + +python prepare_dataset.py --type preference \ + --data_input_dirs "PATH/TO/PREFERENCE/DATA" \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh new file mode 100755 index 000000000..8d3d6c2c2 --- /dev/null +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_prompt_dataset.sh @@ -0,0 +1,13 @@ +SAVE_DIR="" + +rm -rf $SAVE_DIR/cache +rm -rf $SAVE_DIR/jsonl +rm -rf $SAVE_DIR/arrow + +python prepare_dataset.py --type prompt \ + --data_input_dirs /PATH/TO/PROMPT/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow diff --git 
a/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh new file mode 100755 index 000000000..cf937db2a --- /dev/null +++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_sft_dataset.sh @@ -0,0 +1,13 @@ +SAVE_DIR="" + +rm -rf $SAVE_DIR/cache +rm -rf $SAVE_DIR/jsonl +rm -rf $SAVE_DIR/arrow + +python prepare_dataset.py --type sft \ + --data_input_dirs /PATH/TO/SFT/DATASET \ + --conversation_template_config /PATH/TO/CHAT/TEMPLATE/CONFIG.json \ + --tokenizer_dir "" \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow \ diff --git a/applications/ColossalChat/examples/inference/chatio.py b/applications/ColossalChat/examples/inference/chatio.py new file mode 100755 index 000000000..26784f3a3 --- /dev/null +++ b/applications/ColossalChat/examples/inference/chatio.py @@ -0,0 +1,168 @@ +""" +command line IO utils for chatbot +""" + +import abc +import re + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + +class SimpleChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + return input(f"{role}: ") + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + outputs = outputs.strip() + outputs = outputs.split(" ") + now = len(outputs) - 1 + if now > pre: + print(" ".join(outputs[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(outputs[pre:]), flush=True) + return " ".join(outputs) + + +class RichChatIO(ChatIO): + def __init__(self): + self._prompt_session = PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter(words=["!exit", "!reset"], pattern=re.compile("$")) + self._console = Console() + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str) -> str: + self._console.print(f"[bold]{role}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=60) as live: + # Read lines from the stream + for outputs in output_stream: + accumulated_text = outputs + if not accumulated_text: + continue + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. 
+ # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in accumulated_text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return outputs + + +class DummyChatIO(ChatIO): + """ + Dummy ChatIO class for testing + """ + + def __init__(self): + self.roles = [] + self._console = Console() + + def prompt_for_input(self, role) -> str: + self.roles.append(role) + if len(self.roles) == 1: + ret = "Hello" + elif len(self.roles) == 2: + ret = "What's the value of 1+1?" + else: + ret = "exit" + self._console.print(f"[bold]{role}:{ret}") + return ret + + def prompt_for_output(self, role: str) -> str: + self._console.print(f"[bold]{role}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=60) as live: + # Read lines from the stream + for outputs in output_stream: + accumulated_text = outputs + if not accumulated_text: + continue + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in accumulated_text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return outputs + + +simple_io = SimpleChatIO() +rich_io = RichChatIO() +dummy_io = DummyChatIO() diff --git a/applications/ColossalChat/examples/inference/inference.py b/applications/ColossalChat/examples/inference/inference.py new file mode 100755 index 000000000..103bd8d95 --- /dev/null +++ b/applications/ColossalChat/examples/inference/inference.py @@ -0,0 +1,195 @@ +import argparse +import json +import os +from typing import Dict + +import torch +from chatio import dummy_io, rich_io, simple_io +from coati.dataset.conversation import setup_conversation_template +from coati.models import generate_streaming +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + + +def get_gpu_memory(max_gpus=None): + """ + Get the available memory for each GPU. + + Args: + max_gpus (int, optional): The maximum number of GPUs to consider. Defaults to None. 
+ + Returns: + list: A list of available memory for each GPU. + """ + gpu_memory = [] + num_gpus = torch.cuda.device_count() if max_gpus is None else min(max_gpus, torch.cuda.device_count()) + + for gpu_id in range(num_gpus): + # Code to get GPU memory goes here + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs): + """ + Load the model and tokenizer from the specified paths and move the model to the specified device. + + Args: + model_path (str): The path to the pre-trained model. + tokenizer_path (str): The path to the pre-trained tokenizer. + device (str, optional): The device to move the model to. Defaults to "cuda". + **kwargs: Additional keyword arguments to be passed to the `AutoModelForCausalLM.from_pretrained` function. + + Returns: + tuple: A tuple containing the loaded model and tokenizer. + """ + + model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + tokenizer.pad_token = tokenizer.eos_token + model.to(device) + + return model, tokenizer + + +def _set_default_generate_kwargs(model: PreTrainedModel) -> Dict: + """ + Set default keyword arguments for generation based on the given model. + + Args: + model (PreTrainedModel): The model used for generation. + + Returns: + Dict: A dictionary containing the default keyword arguments for generation. + """ + unwrapped_model = model + new_kwargs = {} + # Use huggingface models method directly + if hasattr(unwrapped_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation + + if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation + return new_kwargs + + +def generation_wrapper(*args, **kwargs): + input_ids = args[1] + tokenizer = args[2] + for output in generate_streaming(*args, **kwargs): + yield tokenizer.batch_decode(output[:, input_ids.size(1) :], skip_special_tokens=True)[0] + + +def main(args): + conversation_template_config = json.load(open(args.conversation_template_config, "r", encoding="utf8")) + + max_new_tokens = args.max_new_tokens + model_max_length = args.model_max_length + model, tokenizer = load_model_and_tokenizer( + args.model_path, args.tokenizer_path or args.model_path, local_files_only=True + ) + + assert max_new_tokens <= model_max_length + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + tokenizer.padding_side = "left" + + model_kwargs = { + "max_new_tokens": max_new_tokens, + # 'early_stopping': True, + # 'top_k': -1, + # 'top_p': 1.0, + # 'temperature': 1.0, + # 'temperature':0.1, + } + round = 1 + + conv = setup_conversation_template(tokenizer, conversation_template_config, args.conversation_template_config) + + while True: + if args.io == "simple": + chat_io = simple_io + elif args.io == "rich": + 
chat_io = rich_io + elif args.io == "dummy": + chat_io = dummy_io + else: + raise ValueError(f"Unknown io type: {args.io}") + # raw_text = print(">>> Human:", end=" ") + inp = chat_io.prompt_for_input("user") + + if not inp: + print("prompt should not be empty!") + continue + + if inp.strip() == "clear": + conv.clear() + os.system("clear") + continue + + if inp.strip() == "exit": + print("End of chat.") + break + + query_text = inp.strip() + + conv.append_message("user", query_text) + + chat_io.prompt_for_output("assistant") + + prompt = conv.get_prompt(add_generation_prompt=True) + print(prompt + "") + input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to( + torch.cuda.current_device() + ) + default_generate_kwargs = _set_default_generate_kwargs(model) + model_kwargs.update(default_generate_kwargs) + output_stream = generation_wrapper( + model, + input_ids, + tokenizer, + max_length=model_max_length, + temperature=0.7, + early_stopping=True, + stop_token_ids=conversation_template_config["stop_ids"], + **model_kwargs, + ) + + # print(f">>> Assistant:", end=" ") + outputs = chat_io.stream_output(output_stream) + + conv.append_message("assistant", outputs.strip()) + + with open("round.txt", mode="a", encoding="utf-8") as f: + f.write("\n\n" + "=" * 10 + "\n") + f.write(f"round {round}:\n{conv.save_prompt()}\n\n") + f.write("=" * 10 + "\n") + + # print(f">>> Assistant:", end=" ") + + round += 1 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--tokenizer_path", type=str, default=None) + parser.add_argument("--conversation_template_config", type=str, default=None) + parser.add_argument("--model_max_length", type=int, default=2048) + parser.add_argument("--max_new_tokens", type=int, default=512) + parser.add_argument("--io", type=str, default="rich", choices=["simple", "rich", "dummy"]) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/inference/README.md b/applications/ColossalChat/examples/inference/web_chatbot/README.md old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/inference/README.md rename to applications/ColossalChat/examples/inference/web_chatbot/README.md diff --git a/applications/Chat/inference/locustfile.py b/applications/ColossalChat/examples/inference/web_chatbot/locustfile.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/inference/locustfile.py rename to applications/ColossalChat/examples/inference/web_chatbot/locustfile.py diff --git a/applications/Chat/inference/requirements.txt b/applications/ColossalChat/examples/inference/web_chatbot/requirements.txt old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/inference/requirements.txt rename to applications/ColossalChat/examples/inference/web_chatbot/requirements.txt diff --git a/applications/Chat/inference/server.py b/applications/ColossalChat/examples/inference/web_chatbot/server.py old mode 100644 new mode 100755 similarity index 79% rename from applications/Chat/inference/server.py rename to applications/ColossalChat/examples/inference/web_chatbot/server.py index 7c6a61b9e..aec342802 --- a/applications/Chat/inference/server.py +++ b/applications/ColossalChat/examples/inference/web_chatbot/server.py @@ -5,6 +5,7 @@ from typing import Generator, List, Optional import torch import uvicorn +from coati.models import generate_streaming from coati.quant import 
llama_load_quant, low_resource_init from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -13,10 +14,9 @@ from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address from sse_starlette.sse import EventSourceResponse -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM -from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, update_model_kwargs_fn -CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." MAX_LEN = 512 running_lock = Lock() @@ -54,20 +54,22 @@ app.add_middleware( ) -def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): - inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} +def generate_streamingly(prompt, max_length, max_new_tokens, top_k, top_p, temperature): + input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"] # TODO(ver217): streaming generation does not support repetition_penalty now model_kwargs = { - "max_generate_tokens": max_new_tokens, + "max_new_tokens": max_new_tokens, "early_stopping": True, "top_k": top_k, "top_p": top_p, "temperature": temperature, - "prepare_inputs_fn": model.prepare_inputs_for_generation, + "prepare_inputs_fn": None, "update_model_kwargs_fn": update_model_kwargs_fn, } is_first_word = True - generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) + generator = LockedIterator( + generate_streaming(model, input_ids, tokenizer, max_length, **model_kwargs), running_lock + ) for output in generator: output = output.cpu() tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True) @@ -101,9 +103,10 @@ async def event_generator(request: Request, generator: Generator): @app.post("/generate/stream") @limiter.limit("1/second") def generate(data: GenerationTaskReq, request: Request): - prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) + prompt = prompt_processor.preprocess_prompt(data.history) event_source = event_generator( - request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature) + request, + generate_streamingly(prompt, data.max_length, data.max_new_tokens, data.top_k, data.top_p, data.temperature), ) return EventSourceResponse(event_source) @@ -133,6 +136,11 @@ if __name__ == "__main__": "pretrained", help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", ) + parser.add_argument( + "--tokenizer_path", + help="Path to pretrained tokenizer. Can be a local path or a model name from the HuggingFace model hub.", + default=None, + ) parser.add_argument( "--quant", choices=["8bit", "4bit"], @@ -162,26 +170,29 @@ if __name__ == "__main__": if args.quant == "4bit": assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." 
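The rewritten server.py above streams generated tokens from the /generate/stream endpoint as server-sent events, and its generate() handler builds the generation call from the GenerationTaskReq fields used there (history, max_length, max_new_tokens, top_k, top_p, temperature). For orientation, a minimal client sketch follows; the host, port and exact payload schema are assumptions made for illustration, not something this patch defines.

import requests

# Hypothetical client for the /generate/stream SSE endpoint (illustrative only).
payload = {
    "history": [{"instruction": "Count up from 1 to 5.", "response": ""}],
    "max_length": 2048,
    "max_new_tokens": 64,
    "top_k": 50,
    "top_p": 0.9,
    "temperature": 0.7,
}

with requests.post("http://127.0.0.1:8000/generate/stream", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: <text>" lines separated by blank lines.
        if line and line.startswith("data:"):
            print(line[len("data:"):].strip(), end=" ", flush=True)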
- tokenizer = AutoTokenizer.from_pretrained(args.pretrained) + if args.tokenizer_path is None: + args.tokenizer_path = args.pretrained + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, local_files_only=True) if args.profanity_file is not None: censored_words = load_json(args.profanity_file) else: censored_words = [] - prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) + prompt_processor = ChatPromptProcessor(censored_words=censored_words) if args.quant == "4bit": with low_resource_init(): - config = LlamaConfig.from_pretrained(args.pretrained) - model = LlamaForCausalLM(config) + config = AutoConfig.from_pretrained(args.pretrained) + model = AutoModelForCausalLM(config) model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size) model.cuda() else: - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( args.pretrained, load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", + local_files_only=True, ) if args.quant != "8bit": model.half() # seems to fix bugs for some users. @@ -190,3 +201,8 @@ if __name__ == "__main__": config = uvicorn.Config(app, host=args.http_host, port=args.http_port) server = uvicorn.Server(config=config) server.run() + + +""" +python server.py /home/lcyab/data/models/experiments5/checkpoint/experiment5-2023-10-20-21-53-51/modeling/ --tokenizer_path /mnt/vepfs/lcxyc/leaderboard_models/Colossal-LLaMA-2-7b-base/ +""" diff --git a/applications/ColossalChat/examples/inference/web_chatbot/utils.py b/applications/ColossalChat/examples/inference/web_chatbot/utils.py new file mode 100755 index 000000000..82a1a7255 --- /dev/null +++ b/applications/ColossalChat/examples/inference/web_chatbot/utils.py @@ -0,0 +1,78 @@ +import copy +import json +from threading import Lock +from typing import List + +import jieba +import torch +from coati.dataset.conversation import default_conversation +from pydantic import BaseModel, Field + + +def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: + if "past_key_values" in outputs: + model_kwargs["past"] = outputs["past_key_values"] + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + +class Dialogue(BaseModel): + instruction: str = Field(min_length=1, example="Count up from 1 to 500.") + response: str = Field(example="") + + +class ChatPromptProcessor: + SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt." 
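For reference, the update_model_kwargs_fn helper defined above in this utils.py stores the model's KV cache under the "past" key and appends one column to the attention mask after every generated token. A minimal sketch of one decoding step, assuming the module is importable as utils:

import torch
from utils import update_model_kwargs_fn  # import path assumed for the sketch

model_kwargs = {"attention_mask": torch.ones(1, 4, dtype=torch.long)}  # prompt of length 4
fake_outputs = {"past_key_values": None}  # stand-in for the model's real KV cache

model_kwargs = update_model_kwargs_fn(fake_outputs, **model_kwargs)
print(model_kwargs["attention_mask"].shape)  # torch.Size([1, 5]): one extra position per step
print(model_kwargs["past"])                  # the cached key/values (None in this stand-in)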
+ + def __init__(self, censored_words: List[str] = []): + self.censored_words = set([word.lower() for word in censored_words]) + self.conv = copy.deepcopy(default_conversation) + + def preprocess_prompt(self, history: List[Dialogue]) -> str: + self.conv.clear() + for round in history: + self.conv.append_message(self.conv.roles[0], round.instruction) + if len(round.instruction) > 0: + self.conv.append_message(self.conv.roles[1], round.response) + return self.conv.get_prompt() + + def postprocess_output(self, output: str) -> str: + return output.strip() + + def has_censored_words(self, text: str) -> bool: + if len(self.censored_words) == 0: + return False + intersection = set(jieba.cut(text.lower())) & self.censored_words + return len(intersection) > 0 + + +class LockedIterator: + def __init__(self, it, lock: Lock) -> None: + self.lock = lock + self.it = iter(it) + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + return next(self.it) + + +def load_json(path: str): + with open(path) as f: + return json.load(f) diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/ColossalChat/examples/ray/1mmt_prompt.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/ray/1mmt_prompt.py rename to applications/ColossalChat/examples/ray/1mmt_prompt.py diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/ColossalChat/examples/ray/mmmt_prompt.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/ray/mmmt_prompt.py rename to applications/ColossalChat/examples/ray/mmmt_prompt.py diff --git a/applications/Chat/examples/ray/requirements.txt b/applications/ColossalChat/examples/ray/requirements.txt old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/examples/ray/requirements.txt rename to applications/ColossalChat/examples/ray/requirements.txt diff --git a/applications/Chat/examples/ray/test_ci.sh b/applications/ColossalChat/examples/ray/test_ci.sh similarity index 100% rename from applications/Chat/examples/ray/test_ci.sh rename to applications/ColossalChat/examples/ray/test_ci.sh diff --git a/applications/Chat/examples/requirements.txt b/applications/ColossalChat/examples/requirements.txt similarity index 51% rename from applications/Chat/examples/requirements.txt rename to applications/ColossalChat/examples/requirements.txt index 5474dfa16..838590f4b 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/ColossalChat/examples/requirements.txt @@ -1,3 +1,4 @@ pandas>=1.4.1 sentencepiece -colossalai==0.3.3 +colossalai +prompt_toolkit diff --git a/applications/ColossalChat/examples/training_scripts/hostfile b/applications/ColossalChat/examples/training_scripts/hostfile new file mode 100755 index 000000000..d4118dda9 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/hostfile @@ -0,0 +1 @@ +10.20.1.82 diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py new file mode 100755 index 000000000..aae034683 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -0,0 +1,326 @@ +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +from coati.dataset import ( + DataCollatorForPreferenceDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from coati.models import 
convert_to_lora_module, disable_dropout +from coati.trainer import DPOTrainer +from coati.utils import load_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + +def train(args): + # check lora compatibility + if "gemini" in args.plugin and args.lora_rank > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") + + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + parallel_output=False, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + ref_booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + with init_ctx: + if args.use_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + model = AutoModelForCausalLM.from_pretrained(args.pretrain) + disable_dropout(model) + if args.enable_reference_model: + if args.use_flash_attn: + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + else: + 
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) + disable_dropout(ref_model) + else: + ref_model = None + + if args.lora_rank > 0: + model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) + + if args.grad_checkpoint and args.lora_rank == 0: + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + elif args.lora_rank > 0: + coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + train_dataloader = setup_distributed_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + if args.warmup_steps is None: + args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optim, + total_steps=args.max_epochs * num_update_steps_per_epoch, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optim, _, train_dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + dataloader=train_dataloader, + ) + if ref_model is not None: + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + sampler_start_idx = 0 + start_step = 0 + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}") + 
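+            # A checkpoint path containing "modeling" holds saved model weights only, so just the
+            # weights are loaded below; the else-branch instead restores the optimizer, LR scheduler
+            # and sampler position via load_checkpoint so training resumes where it stopped.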
booster.load_model(model, args.checkpoint_path) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + assert isinstance(train_dataloader.sampler, StatefulDistributedSampler) + train_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + trainer = DPOTrainer( + actor=model, + ref_model=ref_model, + booster=booster, + actor_optim=optim, + actor_lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + start_epoch=start_epoch, + save_interval=args.save_interval, + save_dir=args.save_dir, + coordinator=coordinator, + ) + + trainer.fit( + train_preference_dataloader=train_dataloader, + eval_preference_dataloader=None, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}") + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" + ) + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--save_dir", type=str, default="output") + parser.add_argument("--max_length", type=int, default=2048, help="Model max length") + parser.add_argument("--max_epochs", type=int, default=3) + 
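The training loop above optimizes the policy against a frozen reference model through DPOTrainer. For intuition only, the textbook DPO objective (Rafailov et al., 2023) that this kind of setup is built on looks like the sketch below; it is a simplified illustration, not a quotation of coati's implementation.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp, ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Textbook DPO loss over summed per-sequence log-probabilities (illustrative only)."""
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()

# Toy batch of two preference pairs with made-up log-probabilities.
loss = dpo_loss(
    torch.tensor([-12.0, -9.5]), torch.tensor([-14.0, -11.0]),
    torch.tensor([-12.5, -9.8]), torch.tensor([-13.0, -10.5]),
)
print(loss.item())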
parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--enable_reference_model", type=bool, default=True) + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument( + "--lora_train_bias", + type=str, + default="none", + help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers", + ) + parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.sh b/applications/ColossalChat/examples/training_scripts/train_dpo.sh new file mode 100755 index 000000000..80fc30c3d --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.sh @@ -0,0 +1,62 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 8 +# export CUDA_VISIBLE_DEVICES=6 + +PROJECT_NAME="dpo" +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path + +declare -a dataset=( + YOUR/DATA/DIR/arrow/part-00000 + YOUR/DATA/DIR/arrow/part-00001 + YOUR/DATA/DIR/arrow/part-00002 + YOUR/DATA/DIR/arrow/part-00003 + YOUR/DATA/DIR/arrow/part-00004 + YOUR/DATA/DIR/arrow/part-00005 + YOUR/DATA/DIR/arrow/part-00006 + YOUR/DATA/DIR/arrow/part-00007 + YOUR/DATA/DIR/arrow/part-00008 + YOUR/DATA/DIR/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_dpo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --checkpoint_path $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 1000 \ + --save_dir $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --max_epochs 1 \ + --accumulation_steps 4 \ + --batch_size 2 \ + --lr 1e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + 
--weight_decay 0.01 \ + --warmup_steps 100 \ + --grad_checkpoint \ + --use_wandb diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.py b/applications/ColossalChat/examples/training_scripts/train_ppo.py new file mode 100755 index 000000000..7c91fa347 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.py @@ -0,0 +1,506 @@ +import argparse +import json +import os +import resource +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from coati.dataset import ( + DataCollatorForPromptDataset, + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_conversation_template, + setup_distributed_dataloader, +) +from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout +from coati.trainer import PPOTrainer +from coati.utils import load_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + +logger = get_dist_logger() + + +def train(args): + # check lora compatibility + if "gemini" in args.plugin and args.lora_rank > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + booster_policy = None + with init_ctx: + if args.use_flash_attn: + actor = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + local_files_only=True, + ) + ref_model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + local_files_only=True, + ) + reward_model = RewardModel( + args.rm_pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + critic = Critic( + args.rm_pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) + ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) + reward_model = RewardModel(args.rm_pretrain) + critic = Critic(args.rm_pretrain) + # Disable dropout + disable_dropout(actor) + disable_dropout(critic) + + if args.tp > 1: + if 
reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
+            raise ValueError("Reward model and critic model must have the same architecture")
+        if reward_model.model.config.architectures[0] == "BloomForCausalLM":
+            from colossalai.shardformer.policies.bloom import BloomPolicy
+
+            booster_policy = BloomPolicy()
+        elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
+            from colossalai.shardformer.policies.llama import LlamaPolicy
+
+            booster_policy = LlamaPolicy()
+        elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
+            from colossalai.shardformer.policies.gpt2 import GPT2Policy
+
+            booster_policy = GPT2Policy()
+        elif reward_model.model.config.architectures[0] == "ChatGLMModel":
+            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
+
+            booster_policy = ChatGLMPolicy()
+        elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
+            from colossalai.shardformer.policies.opt import OPTPolicy
+
+            booster_policy = OPTPolicy()
+        else:
+            raise ValueError("Unknown model architecture for policy")
+
+    if args.lora_rank > 0:
+        actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
+        critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
+
+    if args.grad_checkpoint and args.lora_rank == 0:
+        actor.gradient_checkpointing_enable()
+        critic.model.gradient_checkpointing_enable()
+        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+    elif args.lora_rank > 0:
+        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
+    # configure tokenizer
+    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
+    if os.path.exists(args.conversation_template_config):
+        with open(args.conversation_template_config, "r", encoding="utf8") as f:
+            conversation_template_config = json.load(f)
+        dist.barrier()
+        conversation_template = setup_conversation_template(
+            tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
+        )
+        stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
+    else:
+        raise ValueError("Conversation template config is not provided or incorrect")
+    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
+        try:
+            # Some tokenizers do not allow the pad_token to be set manually, e.g., Qwen
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError as e:
+            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
+    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
+        logger.warning(
+            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior during training; please consider setting it manually."
+ ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + tokenizer.padding_side = "left" # left padding for generation (online learning) + + # configure generation config + actor.generation_config.update( + pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id + ) + + # configure optimizer + coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}") + actor_optim = HybridAdam( + model_params=actor.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + coordinator.print_on_master(f"setting up optimizer for critic: lr={args.lr}, weight_decay={args.weight_decay}") + critic_optim = HybridAdam( + model_params=critic.parameters(), + lr=args.critic_lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len) + train_prompt_dataloader = setup_distributed_dataloader( + dataset=train_prompt_dataset, + batch_size=args.experience_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + + if len(args.ptx_dataset) > 0: + train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) + train_pretrain_dataloader = setup_distributed_dataloader( + dataset=train_ptx_dataset, + batch_size=args.ptx_batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + else: + train_pretrain_dataloader = None + + if args.warmup_steps is None: + args.warmup_steps = int(0.025 * args.num_episodes) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + actor_lr_scheduler = CosineAnnealingWarmupLR( + optimizer=actor_optim, + total_steps=args.num_episodes, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + critic_lr_scheduler = CosineAnnealingWarmupLR( + optimizer=critic_optim, + total_steps=args.num_episodes, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + 
max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + parallel_output=False, + precision=args.mixed_precision, + ) + custom_plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + parallel_output=False, + precision=args.mixed_precision, + custom_policy=booster_policy, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + if args.plugin != "3d": + custom_plugin = plugin + + actor_booster = Booster(plugin=plugin) + ref_booster = Booster(plugin=plugin) + rm_booster = Booster(plugin=custom_plugin) + critic_booster = Booster(plugin=custom_plugin) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost( + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + dataloader=train_prompt_dataloader, + ) + + critic, critic_optim, _, _, critic_lr_scheduler = critic_booster.boost( + model=critic, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + dataloader=train_prompt_dataloader, + ) + reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader) + ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader) + + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + sampler_start_idx = 0 + start_step = 0 + + if args.rm_checkpoint_path is not None: + if "modeling" in args.rm_checkpoint_path: + rm_booster.load_model(reward_model, args.rm_checkpoint_path) + else: + _, _, _ = load_checkpoint( + load_dir=args.rm_checkpoint_path, + booster=rm_booster, + model=reward_model, + optimizer=None, + lr_scheduler=None, + ) + coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}") + + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + actor_booster.load_model(actor, args.checkpoint_path) + ref_booster.load_model(ref_model, args.checkpoint_path) + coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}") + else: + _, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=actor_booster, + model=actor, + optimizer=actor_optim, + lr_scheduler=actor_lr_scheduler, + ) + _, _, _ = load_checkpoint( + load_dir=args.checkpoint_path, + booster=ref_booster, + model=ref_model, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + ) + assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler) + train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 
1024:.2f} MB" + ) + + if args.critic_checkpoint_path is not None: + if "modeling" in args.critic_checkpoint_path: + critic_booster.load_model(critic, args.critic_checkpoint_path) + else: + _, _, _ = load_checkpoint( + load_dir=args.critic_checkpoint_path, + booster=critic_booster, + model=critic, + optimizer=critic_optim, + lr_scheduler=critic_lr_scheduler, + ) + coordinator.print_on_master(f"Loaded critic checkpoint {args.critic_checkpoint_path}") + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + # configure trainer + trainer = PPOTrainer( + actor_booster, + critic_booster, + actor, + critic, + reward_model, + ref_model, + actor_optim, + critic_optim, + actor_lr_scheduler, + critic_lr_scheduler, + tokenizer=tokenizer, + stop_token_ids=stop_ids, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + train_batch_size=args.train_batch_size, + buffer_limit=args.num_collect_steps * args.experience_batch_size, + max_length=args.max_length, + max_new_tokens=args.max_seq_len, + use_cache=True, + do_sample=True, + temperature=0.7, + accumulation_steps=args.accumulation_steps, + save_dir=args.save_path, + save_interval=args.save_interval, + top_k=50, + use_tp=args.tp > 1, + offload_inference_models="gemini" not in args.plugin, + coordinator=coordinator, + ) + + trainer.fit( + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + prompt_dataloader=train_prompt_dataloader, + pretrain_dataloader=train_pretrain_dataloader, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + actor.eval() + critic.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final actor model checkpoint") + actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}" + ) + coordinator.print_on_master("Start saving final critic model checkpoint") + critic_booster.save_model(critic, os.path.join(trainer.critic_save_dir, "modeling"), shard=True) + coordinator.print_on_master( + f"Saved final critic model checkpoint at episodes {args.num_episodes} at folder {args.save_path}" + ) + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prompt_dataset", nargs="+", default=[]) + parser.add_argument("--ptx_dataset", nargs="+", default=[]) + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], + help="Choose which plugin to use", + ) + parser.add_argument( + "--conversation_template_config", + type=str, + default=None, + help="Path \ + to save conversation template config files.", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", 
type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--critic_checkpoint_path", type=str, default=None) + parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--num_collect_steps", type=int, default=2) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--save_interval", type=int, default=1000) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--experience_batch_size", type=int, default=16) + parser.add_argument("--ptx_batch_size", type=int, default=4) + parser.add_argument("--lora_train_bias", type=str, default="none") + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=9e-6) + parser.add_argument("--critic_lr", type=float, default=9e-6) + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.0) + parser.add_argument("--max_length", type=int, default=2048) + parser.add_argument("--max_seq_len", type=int, default=256) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_ppo.sh b/applications/ColossalChat/examples/training_scripts/train_ppo.sh new file mode 100755 index 000000000..91633978e --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_ppo.sh @@ -0,0 +1,82 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PROJECT_NAME="ppo" + +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT) +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +REWARD_MODEL_PATH="" # local reward model path (from RLHF step 2: Train Reward Model) +CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the 
conversation config file + +declare -a prompt_dataset=( + YOUR/PROMPT/DATA/DIR/arrow/part-00000 + YOUR/PROMPT/DATA/DIR/arrow/part-00001 + YOUR/PROMPT/DATA/DIR/arrow/part-00002 + YOUR/PROMPT/DATA/DIR/arrow/part-00003 + YOUR/PROMPT/DATA/DIR/arrow/part-00004 + YOUR/PROMPT/DATA/DIR/arrow/part-00005 + YOUR/PROMPT/DATA/DIR/arrow/part-00006 + YOUR/PROMPT/DATA/DIR/arrow/part-00007 + YOUR/PROMPT/DATA/DIR/arrow/part-00008 + YOUR/PROMPT/DATA/DIR/arrow/part-00009 +) + +declare -a ptx_dataset=( + YOUR/SFT/DATA/DIR/arrow/part-00000 + YOUR/SFT/DATA/DIR/arrow/part-00001 + YOUR/SFT/DATA/DIR/arrow/part-00002 + YOUR/SFT/DATA/DIR/arrow/part-00003 + YOUR/SFT/DATA/DIR/arrow/part-00004 + YOUR/SFT/DATA/DIR/arrow/part-00005 + YOUR/SFT/DATA/DIR/arrow/part-00006 + YOUR/SFT/DATA/DIR/arrow/part-00007 + YOUR/SFT/DATA/DIR/arrow/part-00008 + YOUR/SFT/DATA/DIR/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_ppo.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --rm_pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --rm_checkpoint_path $REWARD_MODEL_PATH \ + --prompt_dataset ${prompt_dataset[@]} \ + --conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \ + --ptx_coef 0.0 \ + --plugin "zero2" \ + --save_interval 500 \ + --save_path $SAVE_DIR \ + --num_episodes 2000 \ + --num_collect_steps 2 \ + --num_update_steps 1 \ + --experience_batch_size 4 \ + --train_batch_size 4 \ + --accumulation_steps 2 \ + --lr 9e-6 \ + --mixed_precision "bf16" \ + --grad_clip 0.1\ + --weight_decay 0.01 \ + --warmup_steps 40 \ + --grad_checkpoint \ + --use_wandb diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.py b/applications/ColossalChat/examples/training_scripts/train_rm.py new file mode 100755 index 000000000..a0c710f2b --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_rm.py @@ -0,0 +1,342 @@ +import argparse +import json +import math +import os +import resource +from contextlib import nullcontext + +import torch +from coati.dataset import ( + DataCollatorForPreferenceDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module +from coati.trainer import RewardModelTrainer +from coati.utils import load_checkpoint +from transformers import AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + + +def train(args): + # check lora compatibility + if "gemini" in args.plugin and args.lora_rank > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. 
Please use other plugin") + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + booster_policy = None + with init_ctx: + if args.use_flash_attn: + model = RewardModel( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + model = RewardModel(args.pretrain) + + if args.tp > 1: + if model.model.config.architectures[0] == "BloomForCausalLM": + from colossalai.shardformer.policies.bloom import BloomPolicy + + booster_policy = BloomPolicy() + elif model.model.config.architectures[0] == "LlamaForCausalLM": + from colossalai.shardformer.policies.llama import LlamaPolicy + + booster_policy = LlamaPolicy() + elif model.model.config.architectures[0] == "GPT2LMHeadModel": + from colossalai.shardformer.policies.gpt2 import GPT2Policy + + booster_policy = GPT2Policy() + elif model.model.config.architectures[0] == "ChatGLMModel": + from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy + + booster_policy = ChatGLMPolicy() + elif model.model.config.architectures[0] == "OPTForCausalLM": + from colossalai.shardformer.policies.opt import OPTPolicy + + booster_policy = OPTPolicy() + else: + raise ValueError("Unknown model architecture for policy") + + if args.lora_rank > 0: + model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + parallel_output=False, + precision=args.mixed_precision, + custom_policy=booster_policy, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + if args.grad_checkpoint and args.lora_rank == 0: + model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + 
elif args.lora_rank > 0: + coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") + + # configure tokenizer + tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + tokenizer.padding_side = "right" + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + # configure loss function + if args.loss_fn == "log_sig": + loss_fn = LogSigLoss() + elif args.loss_fn == "log_exp": + loss_fn = LogExpLoss() + else: + raise ValueError(f'Unsupported loss function "{args.loss_fn}"') + + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master(f"Load dataset: {args.dataset}") + mode_map = {"train": "train", "valid": "validation", "test": "test"} + train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map) + data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length) + train_dataloader = setup_distributed_dataloader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + math.ceil(args.max_epochs * num_update_steps_per_epoch) + + if args.warmup_steps is None: + args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optim, + total_steps=args.max_epochs * num_update_steps_per_epoch, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optim, _, train_dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + dataloader=train_dataloader, + ) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + sampler_start_idx = 0 + start_step = 0 + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}") + booster.load_model(model, args.checkpoint_path) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + 
load_dir=args.checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + assert isinstance(train_dataloader.sampler, StatefulDistributedSampler) + train_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + trainer = RewardModelTrainer( + model, + booster, + optim, + lr_scheduler, + tokenizer, + loss_fn=loss_fn, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + start_epoch=start_epoch, + save_interval=args.save_interval, + save_dir=args.save_dir, + coordinator=coordinator, + ) + + trainer.fit( + train_preference_dataloader=train_dataloader, + eval_preference_dataloader=None, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final model checkpoint") + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_dir}") + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"], + help="Choose which plugin to use", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" + ) + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--save_dir", type=str, default="output") + parser.add_argument("--max_length", type=int, default=2048, help="Model max length") + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function") + 
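--loss_fn above switches between the two pairwise ranking objectives imported from coati.models (LogSigLoss and LogExpLoss). In the usual Bradley-Terry formulation these are two ways of writing the same quantity; the sketch below follows that convention and is illustrative rather than a copy of the library code.

import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward, rejected_reward):
    # -log(sigmoid(r_chosen - r_rejected))
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

def log_exp_loss(chosen_reward, rejected_reward):
    # log(1 + exp(r_rejected - r_chosen)), algebraically the same value
    return torch.log(1 + torch.exp(rejected_reward - chosen_reward)).mean()

chosen, rejected = torch.tensor([1.2, 0.3]), torch.tensor([0.4, 0.9])
print(log_sig_loss(chosen, rejected).item(), log_exp_loss(chosen, rejected).item())  # equal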
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument( + "--lora_train_bias", + type=str, + default="none", + help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers", + ) + parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_rm.sh b/applications/ColossalChat/examples/training_scripts/train_rm.sh new file mode 100755 index 000000000..e06d9092f --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_rm.sh @@ -0,0 +1,61 @@ +#!/bin/bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} +set_n_least_used_CUDA_VISIBLE_DEVICES 8 + +PROJECT_NAME="rm" +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path + +declare -a dataset=( + YOUR/PREFERENCE/DATA/DIR/arrow/part-00000 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00001 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00002 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00003 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00004 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00005 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00006 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00007 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00008 + YOUR/PREFERENCE/DATA/DIR/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) +FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" + +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 31312 train_rm.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --checkpoint_path /home/yeanbang/data/experiments/rm/hhh_aligh/ckptllama2-rm-2024-01-17-14-43-24/epoch-1_step-1317/modeling \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --dataset ${dataset[@]} \ + --plugin "zero2" \ + --save_interval 1000 \ + --save_dir $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --max_epochs 3 \ + --accumulation_steps 1 \ + --batch_size 8 \ + --lr 5e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --weight_decay 0.01 \ + --warmup_steps 40 \ + --grad_checkpoint \ + --use_wandb diff --git 
a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py new file mode 100755 index 000000000..fcd1a429c --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_sft.py @@ -0,0 +1,311 @@ +import argparse +import json +import math +import os +import resource +from contextlib import nullcontext + +import torch +from coati.dataset import DataCollatorForSupervisedDataset, load_tokenized_dataset, setup_distributed_dataloader +from coati.models import convert_to_lora_module +from coati.trainer import SFTTrainer +from coati.utils import load_checkpoint +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + + +def train(args): + # check lora compatibility + if "gemini" in args.plugin and args.lora_rank > 0: + raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") + if args.plugin == "gemini_auto" and args.accumulation_steps > 1: + raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") + # ============================== + # Initialize Distributed Training + # ============================== + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + """ + Default torch ddp plugin without any acceleration, for + debugging purpose acceleration, for debugging purpose + """ + plugin = TorchDDPPlugin(find_unused_parameters=True) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="static", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=True, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=1, + zero_stage=0, + parallel_output=False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # Temp Fix: Disable lazy init due to version conflict + # init_ctx = ( + # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + # ) + + init_ctx = nullcontext() + with init_ctx: + if args.use_flash_attn: + model = AutoModelForCausalLM.from_pretrained( + args.pretrain, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + use_flash_attention_2=True, + ) + 
coordinator.print_on_master(msg="Flash-attention enabled successfully") + else: + model = AutoModelForCausalLM.from_pretrained(args.pretrain) + if args.lora_rank > 0: + model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) + + if args.grad_checkpoint and args.lora_rank == 0: + # lora layers are not supported by gradient checkpointing + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + elif args.lora_rank > 0: + coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled") + + # configure tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True + ) + if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None: + try: + # Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen + tokenizer.pad_token = tokenizer.eos_token + except AttributeError as e: + logger.warning(f"Unable to set pad token to eos token, {str(e)}") + if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None: + logger.warning( + "The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them." + ) + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + + coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") + coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}") + + # configure optimizer + optim = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # configure dataset + coordinator.print_on_master( + f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len) + train_dataloader = setup_distributed_dataloader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator, + use_tp=args.tp > 1, + ) + coordinator.print_on_master( + f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + math.ceil(args.max_epochs * num_update_steps_per_epoch) + + if args.warmup_steps is None: + args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optim, + total_steps=args.max_epochs * num_update_steps_per_epoch, + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # Flash attention will be disabled because it does NOT support fp32. 
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optim, _, train_dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + dataloader=train_dataloader, + ) + # model = model.to(get_current_device()) + torch.set_default_dtype(torch.float) + + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + sampler_start_idx = 0 + start_step = 0 + if args.checkpoint_path is not None: + if "modeling" in args.checkpoint_path: + coordinator.print_on_master(f"Continued pretrain from checkpoint {args.checkpoint_path}") + booster.load_model(model, args.checkpoint_path) + else: + coordinator.print_on_master(f"Load model checkpoint from {args.checkpoint_path}") + start_epoch, start_step, sampler_start_idx = load_checkpoint( + load_dir=args.checkpoint_path, + booster=booster, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + train_dataloader.sampler.set_start_index(start_index=sampler_start_idx) + + coordinator.print_on_master( + f"Loaded checkpoint {args.checkpoint_path} at epoch {start_epoch} step {start_step}" + ) + coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") + + coordinator.print_on_master( + f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + trainer = SFTTrainer( + model=model, + booster=booster, + optim=optim, + lr_scheduler=lr_scheduler, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + start_epoch=start_epoch, + save_interval=args.save_interval, + save_dir=args.save_path, + coordinator=coordinator, + ) + + trainer.fit( + train_dataloader=train_dataloader, + eval_dataloader=None, + log_dir=args.log_dir, + use_wandb=args.use_wandb, + ) + + if args.lora_rank > 0 and args.merge_lora_weights: + from coati.models.lora import LORA_MANAGER + + # NOTE: set model to eval to merge LoRA weights + LORA_MANAGER.merge_weights = True + model.eval() + # save model checkpoint after fitting on only rank0 + coordinator.print_on_master("Start saving final model checkpoint") + + booster.save_model(model, os.path.join(args.save_path, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {args.max_epochs} at folder {args.save_path}") + + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument( + "--plugin", + type=str, + default="gemini", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"], + help="Choose which plugin to use", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + 
parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--tokenizer_dir", type=str, default=None) + parser.add_argument("--dataset", nargs="+", default=[]) + parser.add_argument( + "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint" + ) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument( + "--lora_train_bias", + type=str, + default="none", + help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers", + ) + parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") + parser.add_argument("--merge_lora_weights", type=bool, default=True) + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--config_file", type=str, default="config_file", help="Config file") + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") + parser.add_argument("--use_flash_attn", default=False, action="store_true") + args = parser.parse_args() + os.makedirs(os.path.dirname(args.config_file), exist_ok=True) + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + train(args) diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.sh b/applications/ColossalChat/examples/training_scripts/train_sft.sh new file mode 100755 index 000000000..d5c394377 --- /dev/null +++ b/applications/ColossalChat/examples/training_scripts/train_sft.sh @@ -0,0 +1,59 @@ +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + + +# export CUDA_VISIBLE_DEVICES=4,5,6 +set_n_least_used_CUDA_VISIBLE_DEVICES 4 +PROJECT_NAME="sft" +PARENT_SAVE_DIR="" # Path to a folder to save checkpoints +PARENT_TENSORBOARD_DIR="" # Path to a folder to save logs +PARENT_CONFIG_FILE="" # Path to a folder to save training config logs +PRETRAINED_MODEL_PATH="" # huggingface or local model path +PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path +declare -a dataset=( + YOUR/SFT/DATA/DIR/arrow/part-00000 + YOUR/SFT/DATA/DIR/arrow/part-00001 + YOUR/SFT/DATA/DIR/arrow/part-00002 + YOUR/SFT/DATA/DIR/arrow/part-00003 + YOUR/SFT/DATA/DIR/arrow/part-00004 + YOUR/SFT/DATA/DIR/arrow/part-00005 + YOUR/SFT/DATA/DIR/arrow/part-00006 + YOUR/SFT/DATA/DIR/arrow/part-00007 + YOUR/SFT/DATA/DIR/arrow/part-00008 + YOUR/SFT/DATA/DIR/arrow/part-00009 +) + +TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) 
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}" +SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" +CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json" + +# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size +colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \ + --pretrain $PRETRAINED_MODEL_PATH \ + --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ + --save_interval 4000 \ + --dataset ${dataset[@]} \ + --save_path $SAVE_DIR \ + --config_file $CONFIG_FILE \ + --lora_rank 0 \ + --plugin zero2 \ + --batch_size 8 \ + --max_epochs 1 \ + --accumulation_steps 1 \ + --lr 2e-5 \ + --max_len 2048 \ + --grad_checkpoint \ + --use_wandb diff --git a/applications/Chat/pytest.ini b/applications/ColossalChat/pytest.ini old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/pytest.ini rename to applications/ColossalChat/pytest.ini diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt new file mode 100755 index 000000000..de5f6160e --- /dev/null +++ b/applications/ColossalChat/requirements.txt @@ -0,0 +1,24 @@ +transformers==4.34.1 +huggingface_hub==0.17.3 +tqdm +datasets +loralib +colossalai>=0.3.6 +torch>=1.12.1 +langchain +tokenizers +fastapi +sse_starlette +wandb +sentencepiece +gpustat +packaging +autoflake==2.2.1 +black==23.9.1 +tensorboard +six==1.16.0 +datasets +ninja==1.11.1 +sentencepiece==0.1.99 +flash-attn +tiktoken diff --git a/applications/Chat/setup.py b/applications/ColossalChat/setup.py old mode 100644 new mode 100755 similarity index 97% rename from applications/Chat/setup.py rename to applications/ColossalChat/setup.py index eb44b6203..37503920a --- a/applications/Chat/setup.py +++ b/applications/ColossalChat/setup.py @@ -32,7 +32,7 @@ setup( license="Apache Software License 2.0", url="https://github.com/hpcaitech/Coati", install_requires=fetch_requirements("requirements.txt"), - python_requires=">=3.6", + python_requires=">=3.7", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", diff --git a/applications/Chat/tests/__init__.py b/applications/ColossalChat/tests/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/tests/__init__.py rename to applications/ColossalChat/tests/__init__.py diff --git a/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py new file mode 100644 index 000000000..9f85b4beb --- /dev/null +++ b/applications/ColossalChat/tests/generate_dummy_datasets_for_testing.py @@ -0,0 +1,72 @@ +import argparse +import json +import os + +sft_seed = { + "messages": [ + {"from": "human", "content": "Give three tips for staying healthy."}, + { + "from": "assistant", + "content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.", + }, + ] +} +prompt_seed = { + "messages": [ + {"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."}, + { + "from": "assistant", + "content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. 
Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.", + }, + ] +} +preference_seed = { + "context": [ + {"from": "human", "content": "What kind of noises did dinosaurs make?"}, + { + "from": "assistant", + "content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be", + }, + {"from": "human", "content": "yes they did"}, + { + "from": "assistant", + "content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.", + }, + {"from": "human", "content": "you cant read"}, + ], + "chosen": [{"from": "assistant", "content": "You can read?"}], + "rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}], +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + required=True, + default=None, + help="The output dir", + ) + parser.add_argument( + "--data_type", + type=str, + required=True, + default=None, + help="The type of data", + ) + args = parser.parse_args() + if args.data_type == "sft": + seed = sft_seed + elif args.data_type == "prompt": + seed = prompt_seed + elif args.data_type == "preference": + seed = preference_seed + else: + raise ValueError(f"Unknown data type {args.data_type}") + + line = json.dumps(seed, ensure_ascii=False) + "\n" + for idx in [1, 2, 3]: + with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f: + for i in range(1000): + f.write(line) + f.write(line) diff --git a/applications/ColossalChat/tests/llama.json b/applications/ColossalChat/tests/llama.json new file mode 100644 index 000000000..482ff9e65 --- /dev/null +++ b/applications/ColossalChat/tests/llama.json @@ -0,0 +1,8 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant: ' + bos_token }}{% endif %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "stop_ids": [ + 29871, + 2 + ] +} diff --git a/applications/ColossalChat/tests/opt.json b/applications/ColossalChat/tests/opt.json new file mode 100644 index 000000000..6d47666bb --- /dev/null +++ b/applications/ColossalChat/tests/opt.json @@ -0,0 +1,17 @@ +{ + "chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human: ' + bos_token + message['content'].strip() + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'].strip() + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + bos_token + message['content'].strip() + eos_token }}{% endif %}{% endfor %}", + "system_message": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + "human_line_start": [ + 2 + ], + "human_line_end": [ + 2 + ], + "assistant_line_start": [ + 2 + ], + "assistant_line_end": [ + 2 + ], + "end_of_system_line_position": 0 +} diff --git a/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl new file mode 100644 index 000000000..2e11a91c6 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/dpo/test_dpo_data.jsonl @@ -0,0 +1 @@ +{"context": [{"from": "human", "content": "If I were to give you some coordinates, could you tell me where the location is?"}], "chosen": [{"from": "assistant", "content": "Certainly! I can tell you if the given coordinates are within any geographic location, such as a country or city. However, I cannot give accurate directions to any specific location without more information. Are you looking for a more general geographic location, or do you want more precise directions?"}], "rejected": [{"from": "assistant", "content": "Yes, I can determine the location and give you some details about the area. Please provide the latitude and longitude coordinates, and I will return a description of the location based on these coordinates. Please note that precise geographical coordinates are often imprecise, so please provide the best possible coordinates you have available. Here are some examples:\n-The coordinates 41.543554,-71.09471 locate the Boston Commons in Boston, Massachusetts\n-The coordinates 37.42242, -122.083733 locate the Palace of Fine Arts in San Francisco, California \n-The coordinates 25.774145, -80.190262 locate Canal Street in Miami Beach, Florida\n-The coordinates 35. journalisten nl 48.056944, -106.277778 locate the Wyoming State Penitentiary in Rawlins, Wyoming\n-The coordinates 51.507351,-0.127758 locate Tower Bridge in London, England\n-The coordinates 44.885618, -79.422594 locate the Stratosphere Casino, Hotel and Tower in Las Vegas, Nevada"}]} diff --git a/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl new file mode 100644 index 000000000..21c4d9dc7 --- /dev/null +++ b/applications/ColossalChat/tests/test_data/sft/test_sft_data.jsonl @@ -0,0 +1 @@ +{"messages": [{"from": "human", "content": "what are some pranks with a pen I can do?"}, {"from": "assistant", "content": "Are you looking for practical joke ideas?"}, {"from": "human", "content": "Yes, exactly!"}, {"from": "assistant", "content": "You could try attaching a small buzzer to a pen. When someone tries to use it, they'll be surprised by the sound!"}, {"from": "human", "content": "That's a funny one! Any other ideas?"}, {"from": "assistant", "content": "Another idea is the classic 'ink explosion' prank. 
Replace the ink of a pen with disappearing ink, and watch the confusion when the ink vanishes from the paper."}]} diff --git a/applications/ColossalChat/tests/test_data_preparation.sh b/applications/ColossalChat/tests/test_data_preparation.sh new file mode 100755 index 000000000..a7689cdc6 --- /dev/null +++ b/applications/ColossalChat/tests/test_data_preparation.sh @@ -0,0 +1,260 @@ +#!/usr/bin/env bash +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +set -xu + +if [ -z "$SFT_DATASET" ]; then + echo "Please set \$SFT_DATASET to the path to sft dataset." + exit 1 +fi + +if [ -z "$PROMPT_DATASET" ]; then + echo "Please set \$PROMPT_DATASET to the path to prompts." + exit 1 +fi + +if [ -z "$PREFERENCE_DATASET" ]; then + echo "Please set \$SFT_DATASET to the path to sft dataset." + exit 1 +fi + +NUM_RETRY=3 +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +BASE_TEMP_DIR=$BASE_DIR/temp +TEST_DIR=$BASE_DIR/tests +EXAMPLES_DIR=$BASE_DIR/examples +DATA_SAVE_PATH=$BASE_TEMP_DIR/rlhf_data +CONFIG_DIR=$BASE_DIR/config +# Skip those tests due to CI tests timeout +MODELS=('llama') + +if [ ! -d "$BASE_TEMP_DIR" ]; then + mkdir "$BASE_TEMP_DIR" + echo "Directory created successfully" +else + echo "Directory already exists" +fi + +if [ ! -d "$DATA_SAVE_PATH" ]; then + mkdir "$DATA_SAVE_PATH" + echo "Directory created successfully" +else + echo "Directory already exists" +fi + + +export OMP_NUM_THREADS=8 + +# install requirements +pip install -r $EXAMPLES_DIR/requirements.txt + +get_data_input_dirs() { + local data_type=$1 + if [[ $data_type == "sft" ]]; then + echo "$SFT_DATASET" + elif [[ $data_type == "prompt" ]]; then + echo "$PROMPT_DATASET" + elif [[ $data_type == "preference" ]]; then + echo "$PREFERENCE_DATASET" + else + echo "Unknown data type $data_type" + exit 1 + fi +} + +get_conversation_template_config() { + local model=$1 + if [[ $model == "llama" ]]; then + echo "$TEST_DIR/llama.json" + elif [[ $model == "opt" ]]; then + echo "$TEST_DIR/opt.json" + else + echo "Unknown model $model" + exit 1 + fi +} + +get_tokenizer_dirs() { + local model=$1 + if [[ $model == "llama" ]]; then + echo "hf-internal-testing/llama-tokenizer" + elif [[ $model == "opt" ]]; then + echo "facebook/opt-125m" + else + echo "Unknown model $model" + exit 1 + fi +} + +random_choice() { + local arr=("$@") + local len=${#arr[@]} + local idx=$((RANDOM % len)) + echo ${arr[$idx]} +} + +echo "Prepare dummy data for testing..." +python $TEST_DIR/generate_dummy_datasets_for_testing.py \ + --data_dir $(get_data_input_dirs sft) \ + --data_type "sft" + +python $TEST_DIR/generate_dummy_datasets_for_testing.py \ + --data_dir $(get_data_input_dirs preference) \ + --data_type "preference" + +python $TEST_DIR/generate_dummy_datasets_for_testing.py \ + --data_dir $(get_data_input_dirs prompt) \ + --data_type "prompt" + +echo "[Test]: testing prepare_preference_dataset.py ..." 
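The three preparation tests below all follow the same pattern: wipe the cache/jsonl/arrow output directories, invoke prepare_dataset.py with the flags shown in this script, and retry up to NUM_RETRY times. A minimal Python sketch of that control flow follows; the run_prepare wrapper itself is illustrative and not part of the test suite, only the CLI flags are taken from this script.

    import shutil
    import subprocess
    from pathlib import Path

    def run_prepare(prepare_script: str, data_type: str, input_dirs: str, template: str,
                    tokenizer_dir: str, out_root: str, num_retry: int = 3) -> bool:
        # Hypothetical wrapper mirroring the shell loops below: clean the output
        # directories, run prepare_dataset.py, and retry on a non-zero exit code.
        out = Path(out_root)
        for attempt in range(1, num_retry + 1):
            for sub in ("cache", "jsonl", "arrow"):
                shutil.rmtree(out / sub, ignore_errors=True)
            cmd = [
                "python", prepare_script,
                "--type", data_type,
                "--data_input_dirs", input_dirs,
                "--conversation_template_config", template,
                "--tokenizer_dir", tokenizer_dir,
                "--data_cache_dir", str(out / "cache"),
                "--data_jsonl_output_dir", str(out / "jsonl"),
                "--data_arrow_output_dir", str(out / "arrow"),
                "--max_length", "400",
                "--num_samples_per_datafile", "100",
                "--num_spliced_dataset_bins", "1",
            ]
            print(f"[Test]: {data_type}, attempt {attempt}")
            if subprocess.run(cmd).returncode == 0:
                return True
        return False
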
+ +# FIXME: This is a hack to skip tests that are not working +SKIPPED_TESTS=( +) + +# test prepare_preference_dataset +for model in ${MODELS[@]}; do + data_type="preference" + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then + echo "[Test]: Skipped $model-$data_type" + continue + fi + cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache + jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl + arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow + rm -rf $cache_dir + rm -rf $jsonl_dir + rm -rf $arrow_dir + data_input_dirs=$(get_data_input_dirs $data_type) + tokenizer_dir=$(get_tokenizer_dirs $model) + conversation_template=$(get_conversation_template_config $model) + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$data_type, attempt $i" + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \ + --type preference \ + --data_input_dirs $data_input_dirs \ + --conversation_template_config $conversation_template \ + --tokenizer_dir $tokenizer_dir \ + --data_cache_dir $cache_dir \ + --data_jsonl_output_dir $jsonl_dir \ + --data_arrow_output_dir $arrow_dir \ + --max_length 400 \ + --num_samples_per_datafile 100 \ + --num_spliced_dataset_bins 1 + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$data_type" + exit 1 + fi +done + +echo "[Test]: testing prepare_sft_dataset.py ..." + +# FIXME: This is a hack to skip tests that are not working +SKIPPED_TESTS=( +) + +# test prepare_sft_dataset +for model in ${MODELS[@]}; do + data_type="sft" + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then + echo "[Test]: Skipped $model-$data_type" + continue + fi + cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache + jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl + arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow + data_input_dirs=$(get_data_input_dirs $data_type) + tokenizer_dir=$(get_tokenizer_dirs $model) + conversation_template=$(get_conversation_template_config $model) + for i in $(seq $NUM_RETRY); do + rm -rf $cache_dir + rm -rf $jsonl_dir + rm -rf $arrow_dir + echo "[Test]: $model-$data_type, attempt $i" + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \ + --type sft \ + --data_input_dirs $data_input_dirs \ + --conversation_template_config $conversation_template \ + --tokenizer_dir $tokenizer_dir \ + --data_cache_dir $cache_dir \ + --data_jsonl_output_dir $jsonl_dir \ + --data_arrow_output_dir $arrow_dir \ + --max_length 400 \ + --num_samples_per_datafile 100 \ + --num_spliced_dataset_bins 1 + passed=$? + if [ $passed -eq 0 ]; then + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$data_type" + exit 1 + fi +done + +echo "[Test]: testing prepare_prompt_dataset.py ..." 
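Each successful preparation run leaves tokenized shards under the arrow output directory as part-00000, part-00001, and so on, and the training scripts in this patch consume a list of those shard paths. A small sketch of building such a list; the commented load_tokenized_dataset call mirrors its usage in train_sft.py and train_rm.py earlier in this patch.

    import os

    def arrow_shards(arrow_dir: str, num_parts: int) -> list:
        # Python equivalent of the shell pattern: for split in $(seq -f "%05g" 0 N)
        return [os.path.join(arrow_dir, f"part-{i:05d}") for i in range(num_parts)]

    if __name__ == "__main__":
        shards = arrow_shards("temp/rlhf_data/tokenized_llama_sft/arrow", 2)
        print(shards)  # two part-XXXXX paths under the arrow directory
        # In the trainers these paths arrive via --dataset and are loaded with e.g.:
        #   from coati.dataset import load_tokenized_dataset
        #   dataset = load_tokenized_dataset(dataset_paths=shards, mode="train")
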
+
+# FIXME: This is a hack to skip tests that are not working
+SKIPPED_TESTS=(
+)
+
+# test prepare_prompt_dataset
+for model in ${MODELS[@]}; do
+    data_type="prompt"
+    if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
+        echo "[Test]: Skipped $model-$data_type"
+        continue
+    fi
+    cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
+    jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
+    arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
+    data_input_dirs=$(get_data_input_dirs $data_type)
+    tokenizer_dir=$(get_tokenizer_dirs $model)
+    conversation_template=$(get_conversation_template_config $model)
+    for i in $(seq $NUM_RETRY); do
+        rm -rf $cache_dir
+        rm -rf $jsonl_dir
+        rm -rf $arrow_dir
+        echo "[Test]: $model-$data_type, attempt $i"
+        python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
+            --type prompt \
+            --data_input_dirs $data_input_dirs \
+            --conversation_template_config $conversation_template \
+            --tokenizer_dir $tokenizer_dir \
+            --data_cache_dir $cache_dir \
+            --data_jsonl_output_dir $jsonl_dir \
+            --data_arrow_output_dir $arrow_dir \
+            --max_length 400 \
+            --num_samples_per_datafile 100 \
+            --num_spliced_dataset_bins 1
+        passed=$?
+        if [ $passed -eq 0 ]; then
+            break
+        fi
+    done
+    if [ $passed -ne 0 ]; then
+        echo "[Test]: Failed $model-$data_type"
+        exit 1
+    fi
+done
diff --git a/applications/ColossalChat/tests/test_lora.py b/applications/ColossalChat/tests/test_lora.py
new file mode 100755
index 000000000..4ea9e1a15
--- /dev/null
+++ b/applications/ColossalChat/tests/test_lora.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from coati.models import convert_to_lora_module
+from torch.utils.data import DataLoader, TensorDataset
+
+
+class SimpleNN(nn.Module):
+    def __init__(self, input_size, hidden_size, num_classes):
+        super(SimpleNN, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size, num_classes)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        return out
+
+
+def test_overfit():
+    input_size = 1000
+    hidden_size = 200
+    num_classes = 5
+    batch_size = 64
+    learning_rate = 0.01
+    num_epochs = 200
+
+    # Synthesized dataset
+    X = torch.randn(batch_size, input_size)
+    Y = torch.randint(0, num_classes, (batch_size,))
+
+    # Convert to DataLoader
+    dataset = TensorDataset(X, Y)
+    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    # Build and convert model
+    model = SimpleNN(input_size, hidden_size, num_classes)
+    weight_to_compare = model.fc1.weight.detach().clone()
+    model = convert_to_lora_module(model, lora_rank=30)
+
+    # Loss and optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    # Train the model
+    for _ in range(num_epochs):
+        for i, (inputs, labels) in enumerate(loader):
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            print(loss)
+            # Backward and optimize
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+    # Check if model has overfitted
+    outputs = model(X)
+    _, predicted = torch.max(outputs.data, 1)
+    total = Y.size(0)
+    correct = (predicted == Y).sum().item()
+    assert correct / total > 0.95, "The model has not overfitted to the synthesized dataset"
+    assert (weight_to_compare - model.fc1.weight).abs().sum() < 0.01
+
+
+if __name__ == "__main__":
+    test_overfit()
diff --git a/applications/ColossalChat/tests/test_templating.sh
b/applications/ColossalChat/tests/test_templating.sh new file mode 100755 index 000000000..7fefede47 --- /dev/null +++ b/applications/ColossalChat/tests/test_templating.sh @@ -0,0 +1,97 @@ + +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +BASE_TEMP_DIR=$BASE_DIR/temp +EXAMPLES_DIR=$BASE_DIR/examples +TEST_DATA_DIR=$BASE_DIR/tests/test_data +DATA_SAVE_PATH=$BASE_TEMP_DIR/tests +CONFIG_DIR=$BASE_DIR/config + +MODELS=("colossal-llama2" "llama2" "zephyr" "mistral" "chatGLM2" "Qwen" "Vicuna" "Yi") + +get_pretrain() { + local model=$1 + if [[ $model == "colossal-llama2" ]]; then + echo "hpcai-tech/Colossal-LLaMA-2-7b-base" + elif [[ $model == "llama2" ]]; then + echo "hf-internal-testing/llama-tokenizer" + elif [[ $model == "zephyr" ]]; then + echo "HuggingFaceH4/zephyr-7b-beta" + elif [[ $model == "mistral" ]]; then + echo "mistralai/Mistral-7B-Instruct-v0.2" + elif [[ $model == "chatGLM2" ]]; then + echo "THUDM/chatglm2-6b" + elif [[ $model == "Qwen" ]]; then + echo "Qwen/Qwen-7B-Chat" + elif [[ $model == "Vicuna" ]]; then + echo "lmsys/vicuna-7b-v1.5" + elif [[ $model == "Yi" ]]; then + echo "01-ai/Yi-6B-Chat" + else + echo "Unknown model $model" + exit 1 + fi +} + +get_conversation_template_config() { + local model=$1 + echo "$CONFIG_DIR/conversation_template/$model.json" +} + +# Test SFT data Preparation +for model in ${MODELS[@]}; do + echo "Testing SFT data templating for $model" + SAVE_DIR=$DATA_SAVE_PATH/sft/$model + rm -rf $SAVE_DIR/cache + rm -rf $SAVE_DIR/jsonl + rm -rf $SAVE_DIR/arrow + pretrain=$(get_pretrain $model) + conversation_template_config=$(get_conversation_template_config $model) + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \ + --tokenizer_dir $pretrain \ + --conversation_template_config $conversation_template_config \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow + passed=$? + if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the SFT data templating for $model" + exit 1 + fi + python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/sft/test_sft_data.jsonl \ + --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type sft + passed=$? + if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the SFT data templating test for $model" + exit 1 + fi +done + + +# Test DPO/PPO data Preparation +for model in ${MODELS[@]}; do + echo "Testing DPO/PPO data templating for $model" + SAVE_DIR=$DATA_SAVE_PATH/dpo/$model + rm -rf $SAVE_DIR/cache + rm -rf $SAVE_DIR/jsonl + rm -rf $SAVE_DIR/arrow + pretrain=$(get_pretrain $model) + conversation_template_config=$(get_conversation_template_config $model) + python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type preference --data_input_dirs $TEST_DATA_DIR/dpo \ + --tokenizer_dir $pretrain \ + --conversation_template_config $conversation_template_config \ + --data_cache_dir $SAVE_DIR/cache \ + --data_jsonl_output_dir $SAVE_DIR/jsonl \ + --data_arrow_output_dir $SAVE_DIR/arrow + passed=$? + if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the DPO data templating for $model" + exit 1 + fi + python $BASE_DIR/tests/verify_chat_data.py --data_source $TEST_DATA_DIR/dpo/test_dpo_data.jsonl \ + --to_verify_file $SAVE_DIR/jsonl/part-00005.jsonl --data_type dpo + passed=$? 
+ if [ $passed -ne 0 ]; then + echo "[Test]: Failed in the DPO data templating test for $model" + exit 1 + fi +done diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh new file mode 100755 index 000000000..5ba490471 --- /dev/null +++ b/applications/ColossalChat/tests/test_train.sh @@ -0,0 +1,397 @@ +#!/usr/bin/env bash + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES 4 + +set -xu + + +NUM_RETRY=3 +BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE))) +EXAMPLES_DIR=$BASE_DIR/examples +CONFIG_DIR=$BASE_DIR/config +TEMP_DIR=$BASE_DIR/temp +TEST_DIR=$BASE_DIR/tests +MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models +MODELS_DIR=$TEMP_DIR/models_config +# Skip those tests due to CI tests timeout +MODELS=('llama') +PLUGINS=('gemini' 'gemini_auto' 'zero2' 'zero2_cpu' '3d') +LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally + +export OMP_NUM_THREADS=8 + +get_pretrain() { + local model=$1 + if [[ $model == "llama" ]]; then + echo "nickypro/tinyllama-110M" + elif [[ $model == "opt" ]]; then + echo "facebook/opt-125m" + else + echo "Unknown model $model" + exit 1 + fi +} + +get_tokenizer_dirs() { + local model=$1 + if [[ $model == "llama" ]]; then + echo "hf-internal-testing/llama-tokenizer" + elif [[ $model == "opt" ]]; then + echo "facebook/opt-125m" + else + echo "Unknown model $model" + exit 1 + fi +} + + +get_conversation_template_config() { + local model=$1 + if [[ $model == "llama" ]]; then + echo "$TEST_DIR/llama.json" + elif [[ $model == "opt" ]]; then + echo "$TEST_DIR/opt.json" + else + echo "Unknown model $model" + exit 1 + fi +} + +random_choice() { + local arr=("$@") + local len=${#arr[@]} + local idx=$((RANDOM % len)) + echo ${arr[$idx]} +} + + +echo "[Test]: testing sft ..." 
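The SKIPPED_TESTS checks in the training loops below pad both sides with spaces so that an entry such as llama-3d-20 is matched as a whole token rather than a substring. The same semantics in Python, assuming the entries listed in this script:

    SKIPPED_TESTS = [
        "llama-3d-20",           # 3d plugin doesn't support lora
        "llama-gemini_auto-20",  # gemini_auto plugin doesn't support lora
        "llama-gemini-20",       # gemini doesn't support lora
    ]

    def is_skipped(model: str, plugin: str, lora_rank: int) -> bool:
        # exact-token membership, like `" ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank "`
        return f"{model}-{plugin}-{lora_rank}" in SKIPPED_TESTS or f"{model}-{plugin}" in SKIPPED_TESTS

    if __name__ == "__main__":
        print(is_skipped("llama", "3d", 20))    # True: LoRA rank 20 is skipped for the 3d plugin
        print(is_skipped("llama", "zero2", 0))  # False: this combination runs
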
+ +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini_auto-20 # gemini_auto plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) + +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + if [[ $plugin == "3d" ]]; then + tp='4' + bs='8' + fi + grad_accu='2' + # Check if the plugin is either "gemini_auto" or "gemini" and set grad_accu to '1' + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a dataset=() + for split in $(seq -f "%05g" 0 0); do + dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --save_path $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + --lora_rank $lora_rank \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf $MODEL_SAVE_PATH/* + rm -rf $MODELS_DIR/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done + +echo "[Test]: testing reward model ..." 
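The loop below raises the per-device batch size from 2 to 8 when the 3d plugin runs with tp=4. One way to see why, assuming a single node where data parallelism spans the nproc_per_node // tp remaining ranks (an approximation for illustration, not a ColossalAI API):

    def samples_per_update(nproc_per_node: int, tp: int, batch_size: int, accumulation_steps: int) -> int:
        # assumed layout: ranks not consumed by tensor parallelism act as data-parallel replicas
        dp_ranks = nproc_per_node // tp
        return dp_ranks * batch_size * accumulation_steps

    if __name__ == "__main__":
        # default case in this script: 4 GPUs, tp=1, bs=2, accumulation 2
        print(samples_per_update(4, 1, 2, 2))  # 16
        # 3d plugin case: 4 GPUs, tp=4, bs=8, accumulation 2 -> same 16 samples per update
        print(samples_per_update(4, 4, 8, 2))  # 16
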
+ +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini_auto-20 # gemini_auto plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) + +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + if [[ $plugin == "3d" ]]; then + tp='4' + bs='8' + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a dataset=() + for split in $(seq -f "%05g" 0 0); do + dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --save_dir $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + --lora_rank $lora_rank \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf $MODEL_SAVE_PATH/* + rm -rf $MODELS_DIR/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done + + +echo "[Test]: testing ppo ..." 
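The trainers driven by this suite (train_sft.py and train_rm.py earlier in this patch) size their cosine schedule from the dataloader length and, when --warmup_steps is not given, default it to 2.5% of the total number of optimizer updates. A small reproduction of that arithmetic with made-up inputs; train_ppo.py may size its schedule differently.

    def scheduler_steps(dataloader_len: int, accumulation_steps: int, max_epochs: int, warmup_steps=None):
        # mirrors the scheduler setup in train_sft.py / train_rm.py above
        num_update_steps_per_epoch = dataloader_len // accumulation_steps
        total_steps = max_epochs * num_update_steps_per_epoch
        if warmup_steps is None:
            # default when --warmup_steps is not passed: 2.5% of all optimizer updates
            warmup_steps = int(max_epochs * 0.025 * num_update_steps_per_epoch)
        return total_steps, warmup_steps

    if __name__ == "__main__":
        # e.g. a 1000-batch dataloader, accumulation 8, 3 epochs
        print(scheduler_steps(1000, 8, 3))  # (375, 9)
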
+ + +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) + +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue # gemini_auto plugin doesn't support generation + fi + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='4' + ebs='8' + conversation_template=$(get_conversation_template_config $model) + if [[ $plugin == "3d" ]]; then + tp='4' + bs='16' + ebs='32' + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto and gemini doesn't support generation + if [[ $plugin == "gemini_auto" ]]; then + # gemini-auto doesn't support generation + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a prompt_dataset=() + for split in $(seq -f "%05g" 0 0); do + prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split") + done + declare -a ptx_dataset=() + for split in $(seq -f "%05g" 0 0); do + ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \ + --pretrain $pretrain \ + --rm_pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --conversation_template_config $conversation_template \ + --prompt_dataset ${prompt_dataset[@]} \ + --ptx_dataset ${ptx_dataset[@]} \ + --ptx_batch_size 1 \ + --ptx_coef 0.2 \ + --save_path $MODEL_SAVE_PATH \ + --lora_rank $lora_rank \ + --plugin $plugin \ + --num_episodes 5 \ + --num_collect_steps 1 \ + --num_update_steps 1 \ + --experience_batch_size $ebs \ + --train_batch_size $bs \ + --accumulation_steps $grad_accu \ + --lr 9e-6 \ + --mixed_precision "bf16" \ + --grad_clip 1.0 \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --max_seq_len 10 \ + --use_flash_attn + passed=$? + if [ $passed -eq 0 ]; then + rm -rf $MODEL_SAVE_PATH/* + rm -rf $MODELS_DIR/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done + + +echo "[Test]: testing DPO ..." 
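The DPO test below, like the reward-model test above, trains on preference data prepared from records shaped like tests/test_data/dpo/test_dpo_data.jsonl added earlier in this patch: a shared context conversation plus chosen and rejected assistant replies. A minimal reader for that raw record shape:

    import json

    def read_preference_records(path: str) -> list:
        # each jsonl line holds "context", "chosen" and "rejected" conversation turns
        records = []
        with open(path, "r", encoding="utf8") as f:
            for line in f:
                sample = json.loads(line)
                records.append(
                    {
                        "context": [turn["content"] for turn in sample["context"]],
                        "chosen": sample["chosen"][0]["content"],
                        "rejected": sample["rejected"][0]["content"],
                    }
                )
        return records

    # Usage against the fixture shipped with these tests:
    #   recs = read_preference_records("tests/test_data/dpo/test_dpo_data.jsonl")
    #   print(recs[0]["chosen"][:60])
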
+ +SKIPPED_TESTS=( + llama-3d-20 # 3d plugin doesn't support lora + llama-gemini_auto-20 # gemini_auto plugin doesn't support lora + llama-gemini-20 # gemini doesn't support lora +) +GRAD_CKPTS=('--grad_checkpoint') +for lora_rank in ${LORA_RANK[@]}; do + for model in ${MODELS[@]}; do + for plugin in ${PLUGINS[@]}; do + if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then + echo "[Test]: Skipped $model-$plugin-$lora_rank" + continue + elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + pretrain=$(get_pretrain $model) + tokenizer_dir=$(get_tokenizer_dirs $model) + grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") + tp='1' + bs='2' + if [[ $plugin == "3d" ]]; then + tp='4' + bs='8' + fi + grad_accu='2' + # gemini_auto and gemini doesn't support gradient accumulation + if [[ $plugin == "gemini_auto" ]]; then + grad_accu='1' + fi + # gemini_auto doesn't support generation + # (need to calculate ref_model logits through forwarding in inference mode) + if [[ $plugin == "gemini_auto" ]]; then + echo "[Test]: Skipped $model-$plugin" + continue + fi + for i in $(seq $NUM_RETRY); do + echo "[Test]: $model-$plugin-$lora_rank, attempt $i" + declare -a dataset=() + for split in $(seq -f "%05g" 0 0); do + dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split") + done + colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \ + --pretrain $pretrain \ + --tokenizer_dir $tokenizer_dir \ + --dataset ${dataset[@]} \ + --save_dir $MODEL_SAVE_PATH \ + --config_file $MODELS_DIR/config.jsonl \ + --lora_rank $lora_rank \ + --plugin $plugin \ + --batch_size $bs \ + --max_epochs 1 \ + --accumulation_steps $grad_accu \ + --tp $tp \ + --lr 2e-5 \ + $grad_ckpt \ + --max_len 400 \ + --use_flash_attn + passed=$? 
+ if [ $passed -eq 0 ]; then + rm -rf $MODEL_SAVE_PATH/* + rm -rf $MODELS_DIR/* + break + fi + done + if [ $passed -ne 0 ]; then + echo "[Test]: Failed $model-$plugin-$lora_rank" + exit 1 + fi + done + done +done diff --git a/applications/ColossalChat/tests/verify_chat_data.py b/applications/ColossalChat/tests/verify_chat_data.py new file mode 100644 index 000000000..98ae0c1b2 --- /dev/null +++ b/applications/ColossalChat/tests/verify_chat_data.py @@ -0,0 +1,64 @@ +import argparse +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_source", + type=str, + required=True, + default=None, + help="The raw data file", + ) + parser.add_argument( + "--to_verify_file", + type=str, + required=True, + default=None, + help="The file that contains the data to be verified", + ) + parser.add_argument( + "--data_type", + type=str, + required=True, + default=None, + help="The data type", + ) + args = parser.parse_args() + + # Read data + data = [] + with open(args.data_source, "r", encoding="utf8") as f: + for line in f.readlines(): + data.append(json.loads(line)) + to_verify_data = [] + with open(args.to_verify_file, "r", encoding="utf8") as f: + for line in f.readlines(): + to_verify_data.append(json.loads(line)) + + if args.data_type == "sft": + target_lable = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "assistant"] + target_negative_label = [msg["content"].strip() for msg in data[0]["messages"] if msg["from"] == "human"] + + # Read to verify file + + to_verify_lable = to_verify_data[0]["labels_decode"] + for label in target_lable: + assert any([label in s for s in to_verify_lable]), f"Label {label} not in target label {to_verify_lable}" + for label in target_negative_label: + assert all( + [label not in s for s in to_verify_lable] + ), f"Negative label {label} in target label {to_verify_lable}" + elif args.data_type == "dpo": + chosen_lable = data[0]["chosen"][0]["content"].strip() + rejected_lable = data[0]["rejected"][0]["content"].strip() + + # Read to verify file + to_verify_lable_chosen = to_verify_data[0]["chosen_label_decode"] + to_verify_lable_rejected = to_verify_data[0]["rejected_label_decode"] + assert any( + [chosen_lable in s for s in to_verify_lable_chosen] + ), f"Chosen label {chosen_lable} not in target chosen label {to_verify_lable_chosen}" + assert any( + [rejected_lable in s for s in to_verify_lable_rejected] + ), f"Rejected label {rejected_lable} not in target rejected label {to_verify_lable_chosen}" diff --git a/applications/Chat/version.txt b/applications/ColossalChat/version.txt old mode 100644 new mode 100755 similarity index 100% rename from applications/Chat/version.txt rename to applications/ColossalChat/version.txt From e614aa34f324d6583e58f107048922071316344b Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 1 Apr 2024 11:34:58 +0800 Subject: [PATCH 37/43] [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests --- .../colossal_moe/models/mixtral_policy.py | 8 +- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- 
colossalai/inference/engine/policies/bloom.py | 4 +- .../inference/engine/policies/chatglm2.py | 4 +- colossalai/inference/engine/policies/llama.py | 4 +- colossalai/pipeline/stage_manager.py | 87 ++++++++++++++++++- colossalai/shardformer/__init__.py | 2 +- colossalai/shardformer/modeling/llama.py | 14 ++- .../shardformer/policies/base_policy.py | 49 +---------- colossalai/shardformer/policies/bert.py | 32 ++----- colossalai/shardformer/policies/bloom.py | 8 +- colossalai/shardformer/policies/chatglm2.py | 8 +- colossalai/shardformer/policies/falcon.py | 8 +- colossalai/shardformer/policies/gpt2.py | 30 ++----- colossalai/shardformer/policies/gptj.py | 8 +- colossalai/shardformer/policies/llama.py | 33 ++----- colossalai/shardformer/policies/opt.py | 8 +- colossalai/shardformer/policies/t5.py | 19 ++-- colossalai/shardformer/policies/vit.py | 8 +- colossalai/shardformer/policies/whisper.py | 17 ++-- colossalai/shardformer/shard/__init__.py | 3 +- .../shardformer/shard/grad_ckpt_config.py | 87 +++++++++++++++++++ colossalai/shardformer/shard/shard_config.py | 4 + .../language/openmoe/model/openmoe_policy.py | 85 +++++++++--------- tests/kit/model_zoo/transformers/llama.py | 6 +- .../test_t5_pipeline_utils.py | 25 ++++++ .../test_whisper_pipeline_utils.py | 25 ++++++ .../test_model/test_shard_llama.py | 18 ++++ 28 files changed, 396 insertions(+), 213 deletions(-) create mode 100644 colossalai/shardformer/shard/grad_ckpt_config.py diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 23ffbf5d3..c01e02c49 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -109,8 +109,8 @@ class MixtralPolicy(Policy): else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -129,10 +129,10 @@ class MixtralPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f51cb060c..eba7d1c1f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import 
PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. """ @@ -969,6 +970,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1043,6 +1045,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py index f35b50189..5bc47c3c1 100644 --- a/colossalai/inference/engine/policies/bloom.py +++ b/colossalai/inference/engine/policies/bloom.py @@ -114,12 +114,12 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py index 3e1d94f47..c7c6f3b92 100644 --- a/colossalai/inference/engine/policies/chatglm2.py +++ b/colossalai/inference/engine/policies/chatglm2.py @@ -69,11 +69,11 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) held_layers.append(module.output_layer) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: diff --git a/colossalai/inference/engine/policies/llama.py 
b/colossalai/inference/engine/policies/llama.py index 11517d7e8..a57a4e50c 100644 --- a/colossalai/inference/engine/policies/llama.py +++ b/colossalai/inference/engine/policies/llama.py @@ -194,11 +194,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index c8f904208..b0556669b 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,6 +1,7 @@ import contextlib -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup @@ -29,6 +30,8 @@ class PipelineStageManager: ) -> None: assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" + self.num_layers_per_stage = None + self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -69,6 +72,88 @@ class PipelineStageManager: # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None + @property + def control_distribute_layers(self) -> bool: + return self.num_layers_per_stage is not None + + def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: + """Set the distribution configuration. + This allows user to customize the number of layers for each stage. + + Args: + num_model_layers (int): Number of layers in the model. + num_layers_per_stage (List[int]): Number of layers for each stage. 
+ """ + assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage]) + assert sum(num_layers_per_stage) == num_model_layers + assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1) + self.num_model_layers = num_model_layers + self.num_layers_per_stage = num_layers_per_stage + + def distribute_layers( + self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None + ) -> List[int]: + """Divide layers into stages""" + num_stages = self.num_stages if num_stages is None else num_stages + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + + if self.control_distribute_layers: + assert num_layers == self.num_model_layers + return self.num_layers_per_stage + + else: + quotient = num_layers // (num_stages * num_model_chunks) + remainder = num_layers % (num_stages * num_model_chunks) + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages * num_model_chunks + + # deal with the rest layers + if remainder > 0: + start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + def get_stage_index( + self, + layers_per_stage: List[int], + stage: Optional[int] = None, + num_model_chunks: Optional[int] = None, + num_stages: Optional[int] = None, + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: + """ + Get the start index and end index of layers for each stage. + + Args: + layers_per_stage (List[int]): number of layers for each stage + stage (int): the stage index + num_stages (int): number of stages + num_model_chunks (int): number of model chunks + + Returns: + - Tuple[int, int]: the start index and end index of this stage + - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk + + """ + stage = self.stage if stage is None else stage + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + num_stages = self.num_stages if num_stages is None else num_stages + + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + stage_indices = [] + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] + stage_indices.append([start_idx, end_idx]) + + return stage_indices[0] if num_model_chunks == 1 else stage_indices + def is_first_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the first stage. 
diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 77c2af8d1..234e71317 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import ShardConfig, ShardFormer +from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 29dc8200f..eb421c92b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -138,13 +138,25 @@ class LlamaPipelineForwards: next_decoder_cache = () if use_cache else None start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + ) + assert num_ckpt_layers <= end_idx - start_idx + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if idx - start_idx < num_ckpt_layers: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 762e75481..d67ab0a3c 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -196,49 +195,3 @@ class Policy(ABC): List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] - - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages""" - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_stages // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - - def get_stage_index( - self, - layers_per_stage: List[int], - stage: int, - num_model_chunks: int = 1, - num_stages: int = 0, - ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: - """ - Get the start index and end index of layers for each stage. 
- - Args: - layers_per_stage (List[int]): number of layers for each stage - stage (int): the stage index - num_stages (int): number of stages - num_model_chunks (int): number of model chunks - - Returns: - - Tuple[int, int]: the start index and end index of this stage - - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk - - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - stage_indices = [] - for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] - end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] - stage_indices.append([start_idx, end_idx]) - - return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 4d50a3c99..cd7bdcdd6 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -279,16 +279,8 @@ class BertPolicy(Policy): module = self.model.bert if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -298,8 +290,8 @@ class BertPolicy(Policy): } else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -324,16 +316,8 @@ class BertPolicy(Policy): held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: @@ -342,10 +326,10 @@ class BertPolicy(Policy): held_layers.append(module.pooler) else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py 
index e4714c8c1..55b69d5f0 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -203,8 +203,8 @@ class BloomPolicy(Policy): else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -226,11 +226,11 @@ class BloomPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index cbe6254d1..0830d85f1 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -179,10 +179,10 @@ class ChatGLMPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: @@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy): else: module = self.model.transformer - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 16bbc3f23..fe61c406f 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -161,8 +161,8 @@ class FalconPolicy(Policy): else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -181,10 +181,10 @@ class FalconPolicy(Policy): module = self.model.transformer stage_manager = 
self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d1a8c9dce..4bcac3951 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -185,15 +185,8 @@ class GPT2Policy(Policy): held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.wte) held_layers.append(module.wpe) @@ -203,12 +196,12 @@ class GPT2Policy(Policy): if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.ln_f) else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.wpe) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -226,15 +219,8 @@ class GPT2Policy(Policy): module = self.model.transformer if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -243,8 +229,8 @@ class GPT2Policy(Policy): ) } else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b24443298..eab4c214a 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -179,11 +179,11 @@ class GPTJPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = 
self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -200,8 +200,8 @@ class GPTJPolicy(Policy): else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index daa7708c8..18d79f84a 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -164,30 +164,20 @@ class LlamaPolicy(Policy): module = self.model.model if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) } else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config ) } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -204,15 +194,8 @@ class LlamaPolicy(Policy): held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: @@ -221,10 +204,10 @@ class LlamaPolicy(Policy): held_layers.append(module.norm) else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): 
held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 683f3a9d5..98e584be8 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -186,12 +186,12 @@ class OPTPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(module.embed_positions) held_layers.append(module.project_in) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.final_layer_norm) @@ -208,8 +208,8 @@ class OPTPolicy(Policy): else: module = self.model.model.decoder - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index f5f701dc0..0c8ec15fa 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -251,6 +251,8 @@ class T5BasePolicy(Policy): Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." 
# number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -262,7 +264,7 @@ class T5BasePolicy(Policy): # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -273,21 +275,26 @@ class T5BasePolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_t5_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." + if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return stage_manager.get_stage_index( + layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage + ) def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b0f224e22..905398c4d 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -134,10 +134,10 @@ class ViTPolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) return held_layers @@ -149,8 +149,8 @@ class ViTPolicy(Policy): else: module = self.model.vit - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( 
description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 480a4beea..c63f6d1cc 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -300,6 +300,8 @@ class WhisperPolicy(Policy): Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" # number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -311,7 +313,7 @@ class WhisperPolicy(Policy): # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -322,21 +324,24 @@ class WhisperPolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_whisper_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. 
Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" + if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index( + return stage_manager.get_stage_index( layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage, ) diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index acf8a95a4..dff2118c1 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,6 @@ +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig from .shard_config import ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"] diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py new file mode 100644 index 000000000..9c6c2b54e --- /dev/null +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class GradientCheckpointConfig: + gradient_checkpointing_ratio: float = 0.0 + + def get_num_ckpt_layers(self, num_layers: int) -> int: + return int(self.gradient_checkpointing_ratio * num_layers) + + +@dataclass +class PipelineGradientCheckpointConfig(GradientCheckpointConfig): + r""" + The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline parallelism. + Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. + Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. + + It provides the following features: + 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. + 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. + + """ + """ + Args: + gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. + num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. + num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. + num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. + num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. + + Example 1: + num_stages = 8 + num_layers = 80 + num_model_chunks = 1 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0] + + Example 2: + num_stages = 4 + num_layers = 80 + num_model_chunks = 2 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers + ... 
+ + """ + num_stages: Optional[int] = None + num_model_chunks: Optional[int] = None + num_model_layers: Optional[int] = None + num_ckpt_layers_per_stage: Optional[List[int]] = None + + def __post_init__(self): + if self._enable_gradient_checkpointing_ratio: + if not (0 <= self.gradient_checkpointing_ratio <= 1): + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + + if self._enable_customized_ckpt_layers_per_stage: + assert ( + self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None + ) + assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks + assert all( + [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] + ) + self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers + + @property + def _enable_gradient_checkpointing_ratio(self) -> bool: + return self.gradient_checkpointing_ratio is not None + + @property + def _enable_customized_ckpt_layers_per_stage(self) -> bool: + return self.num_ckpt_layers_per_stage is not None + + def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage: + raise RuntimeError("No checkpointed layers information is provided") + + if self._enable_customized_ckpt_layers_per_stage: + assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert num_ckpt_layers <= num_layers + return num_ckpt_layers + else: + return int(self.gradient_checkpointing_ratio * num_layers) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index da27341d9..646b61193 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup from colossalai.pipeline.stage_manager import PipelineStageManager +from .grad_ckpt_config import GradientCheckpointConfig + __all__ = ["ShardConfig"] @@ -23,6 +25,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. + gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. 
""" tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -35,6 +38,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 66a42e017..8ef07bdb9 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Optional, Union @@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"] class OpenMoePolicy(Policy): - def config_sanity_check(self): pass @@ -43,7 +41,8 @@ class OpenMoePolicy(Policy): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag." + ) if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") @@ -97,8 +96,8 @@ class OpenMoePolicy(Policy): else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -117,10 +116,10 @@ class OpenMoePolicy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) @@ -143,7 +142,6 @@ class OpenMoePolicy(Policy): class OpenMoeModelPolicy(OpenMoePolicy): - def __init__(self) -> None: super().__init__() @@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy): class OpenMoeForCausalLMPolicy(OpenMoePolicy): - def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - OpenMoeForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + OpenMoeForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True), ) - ]) + ] + ) } policy.update(new_item) @@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy): def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager 
and self.pipeline_stage_manager.num_stages > 1: - if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) - and self.pipeline_stage_manager.num_stages > 1): + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - }] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] return [] @@ -247,12 +249,13 @@ class OpenMoePipelineForwards: logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): @@ -320,7 +323,8 @@ class OpenMoePipelineForwards: if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -333,12 +337,11 @@ class OpenMoePipelineForwards: if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = (past_key_values[idx] if past_key_values is not None else None) + past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -384,14 +387,16 @@ class OpenMoePipelineForwards: router_z_loss = past_router_z_loss + router_z_loss if stage_manager.is_last_stage(): - return tuple([ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - router_aux_loss, - router_z_loss, - ]) + return tuple( + [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, + ] + ) # always return dict for imediate stage return { "hidden_states": hidden_states, @@ -445,10 +450,11 @@ class OpenMoePipelineForwards: "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
```""" logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: @@ -504,7 +510,6 @@ class OpenMoePipelineForwards: if chunk_head == True: def create_custom_forward(module): - def custom_forward(*inputs): logits = module(inputs[0]) logits = logits.float() @@ -522,8 +527,8 @@ class OpenMoePipelineForwards: for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( create_custom_forward(self.lm_head), - hidden_states[batch_idx:batch_idx + 1, :], - labels[batch_idx:batch_idx + 1, :], + hidden_states[batch_idx : batch_idx + 1, :], + labels[batch_idx : batch_idx + 1, :], ) logits = None else: diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 473064270..9f801e0cc 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -49,9 +49,9 @@ if HAS_LLAMA: loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, + num_hidden_layers=8, + hidden_size=32, + intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, num_labels=16, diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 4ba67225f..1b7b0073f 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.t5 import T5BasePolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class _PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_t5_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_t5_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], @@ -35,7 +57,10 @@ def test_t5_pipeline_layers(): } for i in range(num_test_cases): + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) layers_per_stage, decoder_starting_stage = 
policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 0500e46e8..9f8c1ad32 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.whisper import WhisperPolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class _PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_whisper_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], @@ -34,7 +56,10 @@ def test_whisper_pipeline_layers(): ], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 126ff23a9..55858cbd4 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,6 +5,7 @@ import torch import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False) org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) + if enable_gradient_checkpointing: + org_model.gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable() org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), }, { "tp_size": 1, @@ -108,6 
+115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "use_lazy_init": False, "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] + ), }, { "tp_size": 4, @@ -189,6 +200,13 @@ def run_llama_test(test_config): "precision": "fp16", "zero_stage": 1, "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, + num_model_chunks=2, + num_model_layers=8, + num_ckpt_layers_per_stage=[0, 1, 2, 2], + ), }, ], ) From 7e0ec5a85c73fcc5666b9d218e43865141587dde Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 2 Apr 2024 20:11:18 +0800 Subject: [PATCH 38/43] fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz --- colossalai/shardformer/shard/shard_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 646b61193..ce78a7e94 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -74,8 +74,10 @@ class ShardConfig: self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True - self.enable_sequence_parallelism = True - self.enable_sequence_overlap = True + # This can cause non-in-place param sharding when used without ZeRO. + # It may also slow down training when seq len is small. Plz enable manually. + # self.enable_sequence_parallelism = True + # self.enable_sequence_overlap = True def _infer(self): """ From 8e412a548e5366d1c42bcf386bd185091bd0c280 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Wed, 3 Apr 2024 17:15:47 +0800 Subject: [PATCH 39/43] [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask 
* add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 --- .../booster/plugin/hybrid_parallel_plugin.py | 89 +++- .../plugin/moe_hybrid_parallel_plugin.py | 4 + colossalai/cluster/process_group_mesh.py | 37 +- colossalai/shardformer/layer/__init__.py | 2 + colossalai/shardformer/layer/_operation.py | 389 ++++++++++++++++-- colossalai/shardformer/layer/linear.py | 49 ++- .../shardformer/layer/qkv_fused_linear.py | 44 +- colossalai/shardformer/layer/utils.py | 26 +- colossalai/shardformer/modeling/bert.py | 20 +- colossalai/shardformer/modeling/bloom.py | 18 +- colossalai/shardformer/modeling/chatglm2.py | 22 +- colossalai/shardformer/modeling/gpt2.py | 30 +- colossalai/shardformer/modeling/llama.py | 301 +++++++++++++- colossalai/shardformer/policies/bert.py | 30 +- colossalai/shardformer/policies/bloom.py | 27 +- colossalai/shardformer/policies/chatglm2.py | 22 +- colossalai/shardformer/policies/gpt2.py | 40 +- colossalai/shardformer/policies/llama.py | 87 +++- colossalai/shardformer/shard/shard_config.py | 63 ++- colossalai/zero/low_level/low_level_optim.py | 2 +- tests/kit/model_zoo/transformers/gpt.py | 46 ++- tests/kit/model_zoo/transformers/llama.py | 14 +- .../test_gemini_checkpoint_io.py | 17 +- tests/test_cluster/test_process_group_mesh.py | 30 ++ .../test_gpt2_qkv_fused_linear_1d.py | 27 +- .../test_layer/test_linear_1d.py | 42 +- .../test_layer/test_sequence_parallel.py | 178 ++++++++ tests/test_shardformer/test_model/_utils.py | 40 +- .../test_model/test_shard_bert.py | 23 +- .../test_model/test_shard_bloom.py | 22 + .../test_model/test_shard_chatglm2.py | 22 + .../test_model/test_shard_gpt2.py | 22 + .../test_model/test_shard_llama.py | 101 +++++ 33 files changed, 1630 insertions(+), 256 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_sequence_parallel.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index eba7d1c1f..29cec7cfd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -34,7 +34,8 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3 +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} @@ -53,6 +54,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): shard_config: ShardConfig, dp_group: ProcessGroup, 
tp_group: ProcessGroup, + sp_group: ProcessGroup, use_ddp: bool, ddp_config: dict, custom_policy: Policy, @@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.shard_config = shard_config self.dp_group = dp_group self.tp_group = tp_group + self.sp_group = sp_group self.use_dpp = use_ddp self.require_grad_sync = True @@ -168,13 +171,24 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): Returns: None """ - if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: + + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode == "all_to_all": + return + + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized + # across the tensor parallelism group. + group = self.tp_group + else: + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") + if grads is not None: # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) else: # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module) + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) def forward(self, *args, **kwargs): if self.convert_fn is not None: @@ -727,10 +741,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # Get all working gradients and gradients to be synchronized. all_working_grads = _get_all_working_grads() grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: # Synchronize sequence parallelism gradients if required. - SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync) + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) else: return @@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase): Args: tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + sp_size (int): The size of sequence parallelism. precision (str, optional): Specifies the precision of parameters during training. Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. Defaults to 'fp16'. @@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + sequence_parallelism_mode (str): The sequence parallelism mode. It can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. @@ -938,6 +953,7 @@ class HybridParallelPlugin(PipelinePluginBase): self, tp_size: int, pp_size: int, + sp_size: int = None, precision: str = "fp16", zero_stage: int = 0, enable_all_optimization: bool = False, @@ -945,6 +961,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention: bool = False, enable_jit_fused: bool = False, enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, enable_sequence_overlap: bool = False, parallel_output: bool = True, num_microbatches: Optional[int] = None, @@ -976,14 +993,41 @@ class HybridParallelPlugin(PipelinePluginBase): super().__init__() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" + self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1" + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." 
+ ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all"]: + assert ( + tp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism" + assert ( + pp_size == 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism" + self.sp_size = dist.get_world_size() if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True" + self.sp_size = 1 self.tp_size = tp_size self.pp_size = pp_size - self.dp_size = dist.get_world_size() // (tp_size * pp_size) self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -992,7 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -1033,9 +1077,14 @@ class HybridParallelPlugin(PipelinePluginBase): self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, enable_all_optimization=self.enable_all_optimization, @@ -1043,6 +1092,7 @@ class HybridParallelPlugin(PipelinePluginBase): enable_flash_attention=self.enable_flash_attention, enable_jit_fused=self.enable_jit_fused, enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, gradient_checkpoint_config=gradient_checkpoint_config, @@ -1113,13 +1163,23 @@ class HybridParallelPlugin(PipelinePluginBase): ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): - use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.pp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS]) + else: + dp_group = self.dp_group model = HybridParallelModule( model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_group, + dp_group=dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, 
custom_policy=self.custom_policy, @@ -1149,7 +1209,8 @@ class HybridParallelPlugin(PipelinePluginBase): tp_process_group=self.tp_group, ) else: - if self.dp_size == 1: + zero_dp_size = dist.get_world_size(dp_group) + if zero_dp_size == 1: warnings.warn( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you are not intended to use cpu_offload, please consider set zero_stage=0." @@ -1161,7 +1222,7 @@ class HybridParallelPlugin(PipelinePluginBase): model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info, - dp_process_group=self.dp_group, + dp_process_group=dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, verbose=True, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index ae372dd03..83888e506 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -254,6 +254,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + # TODO: Currently moe only support partially sequence parallel + self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, @@ -365,6 +368,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): shard_config=self.shard_config, dp_group=self.dp_group, tp_group=self.tp_group, + sp_group=self.sp_group, use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index ae3956c69..ccf122695 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -161,7 +161,7 @@ class ProcessGroupMesh: @staticmethod def get_coords_along_axis( - base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]] ) -> List[Tuple[int, ...]]: """Get coordinates along the given axis. @@ -173,13 +173,28 @@ class ProcessGroupMesh: Returns: List[Tuple[int, ...]]: Coordinates along the axis. """ - coords_in_group = [] - for idx in indices_at_axis: - coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + if isinstance(axis, int): + axis = [axis,] + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + def add_index(base_coord, axis, indices_at_axis): + coords_in_group = [] + for idx in indices_at_axis: + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) + return coords_in_group + + coords_in_group = [base_coord] + for ax, indices_at_ax in zip(axis, indices_at_axis): + new_coords_in_group = [] + for coords in coords_in_group: + new_coords_in_group += add_index(coords, ax, indices_at_ax) + coords_in_group = new_coords_in_group + return coords_in_group def create_group_along_axis( - self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. 
@@ -191,10 +206,17 @@ class ProcessGroupMesh: Returns: ProcessGroup: The process group along the given axis which the current process belongs to. """ - indices_at_axis = indices_at_axis or list(range(self._shape[axis])) + if isinstance(axis, int): + axis = [axis,] + if indices_at_axis is not None: + assert isinstance(indices_at_axis[0], int) + indices_at_axis = [indices_at_axis,] + + indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis] reduced_shape = list(self._shape) # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` - reduced_shape[axis] = 1 + for ax in axis: + reduced_shape[ax] = 1 target_group = None # use Cartesian product to generate all combinations of coordinates for base_coord in itertools.product(*[range(s) for s in reduced_shape]): @@ -225,4 +247,3 @@ class ProcessGroupMesh: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) return self._ranks_to_group[ranks_in_group] - \ No newline at end of file diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index c9b4317a6..0e368dbf9 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,4 +1,5 @@ from .attn import AttnMaskType, ColoAttention +from ._operation import all_to_all_comm from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row @@ -26,4 +27,5 @@ __all__ = [ "ParallelModule", "AttnMaskType", "ColoAttention", + "all_to_all_comm", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 241770901..82d37bb4c 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + + # initialization of ring communication + recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = torch.empty_like(v) + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k] + + output_tensors = [] + + handles = communicate_step() + # first round: special case, retrive from local tensor + output_tensors.append(func(**input_to_gather, **input_local)) + for i in range(group_size - 2): + for handle 
in handles: + handle.wait() + + switch_step() + + handles = communicate_step() + + # actual computation + output_tensors.append(func(**send_tensors, **input_local)) + + # final round: special case, no need to send/recv again + for handle in handles: + handle.wait() + output_tensors.append(func(**recv_tensors, **input_local)) + + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) + + +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [ + item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -186,12 +277,25 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {"input": input_} + input_local = {"weight": weight} - if bias is not None: - output = F.linear(input_parallel, weight, bias) + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + ) + + if bias is not None: + output += bias else: - output = F.linear(input_parallel, weight) + input_parallel = _gather(input_, dim, process_group) + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) return output @@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None + + +def _ring_as_reducescatter( + func, 
input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1 +): + # currently only support one single tensor as output + group_size = dist.get_world_size(process_group) + cur_rank = dist.get_rank(process_group) + + # initialization of ring communication + recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 + send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + rank_map = list(dist.get_process_group_ranks(process_group)) + recv_rank = rank_map[recv_rank] + send_rank = rank_map[send_rank] + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() + + output_tensor = func(**input_tensors[0], **input_local) + recv_tensor = torch.empty_like(output_tensor) + send_tensor = output_tensor.clone() + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrive from local tensor + for i in range(group_size - 2): + # actual computation + output_tensor = func(**input_tensors[i + 1], **input_local) + + for handle in handles: + handle.wait() + output_tensor += recv_tensor + + tmp_tensor = send_tensor + send_tensor = output_tensor + output_tensor = tmp_tensor + + handles = communicate_step() + + # final round: special case, no need to send/recv again + output_tensor = func(**input_tensors[-1], **input_local) + for handle in handles: + handle.wait() + output_tensor += recv_tensor + return output_tensor class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): - """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, dim, ring): + ctx.save_for_backward(input_, weight, bias) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.dim = dim + + if ring is True: + input_to_reducescatter = {"input": input_} + input_local = {"weight": weight} + + if bias is not None: + input_to_reducescatter["bias"] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) + else: + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) + + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). 
" + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight, bias = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + + grad_output = _gather(grad_output, dim, process_group) + + # TODO Need to fully optimize + total_input = input_ + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + return grad_input, grad_weight, grad_bias, None, None, None + + +class _ReduceScatterForwardGatherBackward(torch.autograd.Function): + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -351,9 +590,24 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather["input"] = input_ + input_local["other"] = weight - output = torch.matmul(input_parallel, weight) + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim, + ) + + else: + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) if bias is not None: output = output + bias @@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _split(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): - return _gather(grad_output, ctx.dim, ctx.process_group), None, None + if ctx.grad_scale is not None: + grad_output = 
grad_output * ctx.grad_scale + return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None class _ReduceForward(torch.autograd.Function): @@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, dim, process_group): + def forward(ctx, input_, dim, process_group, grad_scale=None): ctx.process_group = process_group ctx.dim = dim + ctx.grad_scale = grad_scale return _gather(input_, dim, process_group) @staticmethod def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + if ctx.grad_scale is not None: + grad_output = grad_output * ctx.grad_scale + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + world_size = dist.get_world_size(process_group) + bsz, _, _ = input_.shape + + # using all_to_all_single when batch size is 1 + if bsz == 1: + return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + else: + return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + + @staticmethod + def backward(ctx, *grad_output): + process_group = ctx.process_group + scatter_dim = ctx.gather_dim + gather_dim = ctx.scatter_dim + return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) + return (return_grad, None, None, None) class HookParameter(torch.autograd.Function): @@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size + if scatter_dim < 2: + input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous() + else: + input_t = ( + input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]) + .transpose(0, 1) + .contiguous() + ) + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + if scatter_dim < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[:gather_dim] + + [ + inp_shape[gather_dim] * seq_world_size, + ] + + inp_shape[gather_dim + 1 :] + ).contiguous() + + def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) @@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, 
ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) -def linear_reducescatter_forward_gather_backward(input_, process_group, dim): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def gather_forward_reducescatter_backward(input_, process_group, dim): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) + + +def reducescatter_forward_gather_backward(input_, process_group, dim): + return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) -def gather_forward_split_backward(input_, dim, process_group): - return _GatherForwardSplitBackward.apply(input_, dim, process_group) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) -def split_forward_gather_backward(input_, dim, process_group): - return _SplitForwardGatherBackward.apply(input_, dim, process_group) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): + return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) def reduce_forward(input_, process_group): @@ -650,3 +985,7 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): return _ReduceBackward.apply(input_, process_group) + + +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index eeb0ef399..7c8619ad8 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,11 +23,13 @@ from colossalai.tensor.d_tensor.api import ( ) from ._operation import ( + gather_forward_reducescatter_backward, gather_forward_split_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -74,7 +76,7 @@ class Linear1D_Col(ParallelModule): device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, - seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, overlap: torch.cuda.Stream = None, skip_bias_add: bool = False, @@ -89,7 +91,7 @@ class Linear1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = 
seq_parallel_dim self.overlap = overlap self.skip_bias_add = skip_bias_add @@ -196,12 +198,18 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: - output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap - ) - else: + + if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + elif self.seq_parallel_mode == "split_gather": + input_parallel = gather_forward_reducescatter_backward( + input_parallel, self.process_group, self.seq_parallel_dim + ) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + elif self.seq_parallel_mode == "ring": + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. @@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None. + seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence. 
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): @@ -245,7 +254,7 @@ class Linear1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + seq_parallel_mode: str = None, seq_parallel_dim: int = 1, parallel_input: bool = True, skip_bias_add: bool = False, @@ -265,7 +274,7 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) @@ -403,18 +412,26 @@ class Linear1D_Row(ParallelModule): output_parallel_list[i], group=self.process_group, async_op=True ) handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) - if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward( + if self.seq_parallel_mode is None: + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "split_gather": + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reducescatter_forward_gather_backward( output_parallel, self.process_group, self.seq_parallel_dim ) - else: - output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "ring": + output = linear_reducescatter_forward_gather_backward( + input_, + self.weight, + process_group=self.process_group, + dim=self.seq_parallel_dim, + ring=True, + ) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 12476d050..dc3634238 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,12 +25,12 @@ from colossalai.tensor.d_tensor.api import ( from ._operation import ( gather_forward_split_backward, - linear_reducescatter_forward_gather_backward, linear_with_async_comm, matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, + reducescatter_forward_gather_backward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. 
gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -175,7 +175,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, - seq_parallel: bool = False, + seq_parallel_mode: str = None, overlap: bool = False, skip_bias_add: bool = False, n_fused: int = 3, @@ -190,7 +190,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): self.in_features = in_features self.out_features = out_features self.gather_output = gather_output - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.overlap = overlap self.skip_bias_add = skip_bias_add self.device = device @@ -312,17 +312,22 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel: - input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap - ) - else: + if self.seq_parallel_mode is None: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) output_parallel = matmul_with_async_comm( input_parallel, self.weight, bias, self.process_group, self.async_communication ) + elif self.seq_parallel_mode == "split_gather": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) + elif self.seq_parallel_mode == "ring": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. @@ -347,7 +352,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. + seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. 
@@ -366,7 +371,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = False, + seq_parallel_mode: str = None, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -385,7 +390,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.seq_parallel = seq_parallel + self.seq_parallel_mode = seq_parallel_mode self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -528,11 +533,15 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = torch.matmul(input_, self.weight) - if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) - else: + if self.seq_parallel_mode is None: + output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) + elif self.seq_parallel_mode == "split_gather": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + elif self.seq_parallel_mode == "ring": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: if self.bias is not None: @@ -702,7 +711,6 @@ class FusedLinear1D_Col(ParallelModule): # process_group=process_group, # is_transposed=False) # linear_1d.bias.data.copy_(sharded_bias.data) - print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 0d2cc1b33..9c6ced445 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -35,17 +35,21 @@ class SeqParallelUtils: return getattr(param, "partial_derived", False) @staticmethod - def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None): + def allreduce_partial_data_grad( + process_group: ProcessGroup, + model: nn.Module = None, + grads: List[torch.Tensor] = None, + ): """ Allreduce partial derived gradients across the specified process group. This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism. Args: - tp_group (ProcessGroup): The process group for gradient synchronization. + process_group (ProcessGroup): The process group for gradient synchronization. model (nn.Module): The model from which gradients will be synchronized. grads (List[torch.Tensor]): The list of gradients to be synchronized. - + only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived. Raises: AssertionError: If both `model` and `grads` are provided or neither is provided. """ @@ -53,22 +57,26 @@ class SeqParallelUtils: assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None." # Get the size of the process group, which determines whether synchronization is needed. 
- tp_size = get_world_size(tp_group) if tp_group is not None else 1 + group_size = get_world_size(process_group) if process_group is not None else 1 - if tp_size == 1: + if group_size == 1: # If the process group size is 1, no synchronization is required. return if model is not None: # If `model` is provided, extract partial derived gradients from the model's parameters. grads = [] + for p in model.parameters(): - if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p): - grads.append(p.grad.data) + if p.grad is not None: + if SeqParallelUtils.is_sp_partial_derived_param(p): + grads.append(p.grad.data) # Flatten and reduce the gradients using the specified process group. + if len(grads) == 0: + return coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) # Unflatten the synchronized gradients and update the model's gradients. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): @@ -76,7 +84,7 @@ class SeqParallelUtils: else: # If `grads` are provided explicitly, synchronize those gradients directly. coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group) + dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 7411e1d0e..0838fcee6 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -186,13 +186,14 @@ class BertPipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) - if encoder_hidden_states is not None: - encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group ) + if encoder_hidden_states is not None: + encoder_hidden_states = split_forward_gather_backward( + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -240,9 +241,10 @@ class BertPipelineForwards: # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index d94c30d29..fe70376e1 100644 --- 
a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -213,10 +213,11 @@ class BloomPipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) start_idx, end_idx = stage_index[0], stage_index[1] for i, (block, layer_past) in enumerate( @@ -261,10 +262,11 @@ class BloomPipelineForwards: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): # Add last hidden state diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a3e000e6e..9207b34d0 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -191,12 +191,11 @@ class ChatGLMPipelineForwards: all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] - if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: @@ -222,12 +221,11 @@ class ChatGLMPipelineForwards: if use_cache: presents = presents + (kv_cache,) - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index ea22cfb15..1306c8aa6 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -218,12 +218,13 @@ class GPT2PipelineForwards: # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - if 
shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] @@ -278,12 +279,13 @@ class GPT2PipelineForwards: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode == "split_gather": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -1141,7 +1143,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = split_forward_gather_backward( hidden_states, dim=1, - process_group=shard_config.tensor_parallel_process_group, + process_group=shard_config.sequence_parallel_process_group, ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -1208,7 +1210,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states = gather_forward_split_backward( hidden_states, dim=1, - process_group=shard_config.tensor_parallel_process_group, + process_group=shard_config.sequence_parallel_process_group, ) hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index eb421c92b..0f1b4ad0a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,18 +1,32 @@ +import math import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + apply_rotary_pos_emb, + repeat_kv, +) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d @@ -438,7 +452,7 @@ class LlamaPipelineForwards: return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config: ShardConfig): +def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb llama_version = 2 @@ 
-459,18 +473,30 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -490,6 +516,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -726,3 +755,261 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ) return forward + + +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = 
torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group) + key_states = all_to_all_comm(key_states, sp_group) + value_states = all_to_all_comm(value_states, sp_group) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + return forward + + 
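Note: the all_to_all branches above rely on a single collective to switch between a sequence-sharded layout ([batch, seq/sp_size, hidden]) before the projections and a head-sharded layout ([batch, seq, hidden/sp_size]) inside attention, then back again for the output projection. A minimal sketch of such an operator follows (hypothetical _AllToAllComm class using plain torch.distributed.all_to_all; the real all_to_all_comm helper lives in colossalai.shardformer.layer and may differ in details such as gradient scaling). The key property is that the backward pass is the same collective with scatter_dim and gather_dim swapped.

import torch
import torch.distributed as dist


class _AllToAllComm(torch.autograd.Function):
    """Scatter `scatter_dim` across the group and gather `gather_dim` from it."""

    @staticmethod
    def forward(ctx, x, group, scatter_dim, gather_dim):
        ctx.group, ctx.scatter_dim, ctx.gather_dim = group, scatter_dim, gather_dim
        world = dist.get_world_size(group)
        # assumes x.size(scatter_dim) is divisible by the group size
        send = [t.contiguous() for t in x.chunk(world, dim=scatter_dim)]
        recv = [torch.empty_like(send[0]) for _ in range(world)]
        dist.all_to_all(recv, send, group=group)
        return torch.cat(recv, dim=gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # the adjoint of an all-to-all is the same all-to-all with the two dims swapped
        grad_input = _AllToAllComm.apply(grad_output, ctx.group, ctx.gather_dim, ctx.scatter_dim)
        return grad_input, None, None, None


# Usage in an attention forward (hypothetical local tensors):
#   q = _AllToAllComm.apply(q_local, sp_group, 2, 1)       # [b, s/p, h]  -> [b, s, h/p]
#   out = _AllToAllComm.apply(ctx_local, sp_group, 1, 2)   # [b, s, h/p]  -> [b, s/p, h]
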
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): + logger = logging.get_logger(__name__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + # modify past_key_values_length when using sequence parallel + past_key_values_length *= sp_size + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
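Note: because the embeddings are split along the sequence dimension right after embed_tokens, the rest of the forward has to translate between local and global shapes, which is why the code above scales past_key_values_length (and, in the attention forward, q_len) by sp_size. A small numeric sanity check, with illustrative sizes that are not taken from the patch:

# Illustrative sizes only:
sp_size, seq_len, num_heads, head_dim = 4, 4096, 32, 128

# "split_gather" / "ring": each rank holds seq_len // sp_size tokens of every head,
# so locally observed lengths must be scaled back up for masks, rotary embeddings
# and the kv-cache length, as the forwards above do.
local_seq = seq_len // sp_size            # 1024 tokens live on this rank
assert local_seq * sp_size == seq_len

# "all_to_all": after the q/k/v all-to-all each rank holds the full sequence but
# only num_heads // sp_size heads (the LLaMA policy later divides num_heads accordingly).
local_heads = num_heads // sp_size        # 8 heads live on this rank
assert local_heads * head_dim * sp_size == num_heads * head_dim
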
+ ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return forward diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index cd7bdcdd6..0a61d8cff 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -66,8 +67,17 @@ class BertPolicy(Policy): else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert" + if sp_mode == "ring": + warnings.warn( + f"For Bert, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[BertLayer] = ModulePolicyDescription( attribute_replacement={ @@ -85,7 +95,7 @@ class BertPolicy(Policy): suffix="attention.self.query", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -93,7 +103,7 @@ class BertPolicy(Policy): suffix="attention.self.key", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": 
overlap, }, ), @@ -101,7 +111,7 @@ class BertPolicy(Policy): suffix="attention.self.value", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -112,7 +122,7 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="attention.output.dropout", @@ -122,14 +132,14 @@ class BertPolicy(Policy): suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), SubModuleReplacementDescription( suffix="output.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="output.dropout", @@ -151,7 +161,7 @@ class BertPolicy(Policy): ] ) - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, @@ -165,12 +175,12 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="output.LayerNorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 55b69d5f0..2becadc3f 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -55,8 +56,18 @@ class BloomPolicy(Policy): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM" + if sp_mode == "ring": + warnings.warn( + f"For BLOOM, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" + overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription( attribute_replacement={ @@ -70,12 +81,12 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), SubModuleReplacementDescription( suffix="self_attention.attention_dropout", @@ -84,12 +95,12 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="mlp.dense_h_to_4h", target_module=col_nn.Linear1D_Col, - 
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={"seq_parallel_mode": sp_mode}, ), ], ) @@ -132,19 +143,19 @@ class BloomPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, target_key=BloomBlock, ) - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 0830d85f1..dabc14bff 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -55,8 +56,17 @@ class ChatGLMPolicy(Policy): norm_cls = col_nn.RMSNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" + if sp_mode == "ring": + warnings.warn( + f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode == "split_gather" + if self.shard_config.enable_tensor_parallelism: policy[ChatGLMModel] = ModulePolicyDescription( attribute_replacement={}, @@ -91,12 +101,12 @@ class ChatGLMPolicy(Policy): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap}, ), SubModuleReplacementDescription( suffix="self_attention.dense", target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0}, ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", @@ -110,12 +120,12 @@ class ChatGLMPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -145,7 +155,7 @@ class ChatGLMPolicy(Policy): ) # use sequence parallel - if use_sequence_parallel: + if sp_mode == "split_gather": self.append_or_create_method_replacement( description={"forward": 
get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 4bcac3951..380a432dc 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Callable, Dict, List @@ -50,8 +51,25 @@ class GPT2Policy(Policy): norm_cls = col_nn.FusedLayerNorm else: norm_cls = col_nn.LayerNorm - use_sequence_parallel = self.shard_config.enable_sequence_parallelism + + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2" + if sp_mode == "ring": + warnings.warn( + f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" + ) + sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap + sp_partial_derived = sp_mode in ["split_gather", "ring"] + use_flash_attention = self.shard_config.enable_flash_attention + # todo: currently sp cannot be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically." + ) + self.shard_config.enable_flash_attention = False + use_flash_attention = False if self.shard_config.enable_tensor_parallelism: policy[GPT2Model] = ModulePolicyDescription( sub_module_replacement=[ @@ -78,7 +96,7 @@ class GPT2Policy(Policy): target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 3, - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), @@ -86,7 +104,7 @@ class GPT2Policy(Policy): suffix="attn.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, }, ), SubModuleReplacementDescription( @@ -94,14 +112,16 @@ class GPT2Policy(Policy): target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 1, - "seq_parallel": use_sequence_parallel, + "seq_parallel_mode": sp_mode, "overlap": overlap, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel_mode": sp_mode, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -133,25 +153,25 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="ln_1", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_2", target_module=norm_cls, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="ln_cross_attn", target_module=norm_cls, ignore_if_not_exist=True, - kwargs={"sp_partial_derived": use_sequence_parallel}, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, target_key=GPT2Block, ) - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ "forward": get_gpt2_flash_attention_forward(), @@ -164,7 +184,7 @@ class GPT2Policy(Policy): "forward": get_gpt_model_forward_for_flash_attn(self.shard_config) } - if 
self.shard_config.enable_sequence_parallelism: + if sp_mode is not None: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} return policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 18d79f84a..bb4551b2c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -12,6 +12,8 @@ from ..modeling.llama import ( LlamaPipelineForwards, get_llama_flash_attention_forward, get_llama_model_forward_for_flash_attn, + get_llama_seq_parallel_attention_forward, + get_llama_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -45,9 +47,74 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm - if self.shard_config.enable_sequence_parallelism: + if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + self.shard_config.enable_sequence_overlap = False + self.shard_config.sequence_parallelism_mode = None + warnings.warn( + f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" + ) + sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None + sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None + sp_group = ( + self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None + ) + sp_partial_derived = sp_mode in ["split_gather", "ring"] + + use_flash_attention = self.shard_config.enable_flash_attention + # Currently sp cannot to be used with flashattention + if sp_mode in ["split_gather", "ring", "all_to_all"]: + if use_flash_attention: + warnings.warn( + f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically." 
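Note: each policy above repeats the same preamble: read the sequence-parallel mode, size and process group from the shard config, reject or downgrade modes the architecture does not support, and decide whether FlashAttention stays enabled. A condensed sketch of that pattern (hypothetical resolve_sp_settings helper; the actual policies inline this logic and differ in which modes they accept and in which flags they mutate):

import warnings


def resolve_sp_settings(shard_config, supported_modes=("split_gather", "ring", "all_to_all")):
    """Hypothetical condensation of the per-policy preamble; returns (sp_mode, sp_size, sp_group)."""
    if not shard_config.enable_sequence_parallelism:
        return None, None, None
    sp_mode = shard_config.sequence_parallelism_mode
    # BERT/BLOOM/ChatGLM2/GPT-2 reject "all_to_all" outright ...
    assert not (sp_mode == "all_to_all" and "all_to_all" not in supported_modes), \
        "all_to_all sequence parallelism is not supported for this model"
    # ... and downgrade "ring" to "split_gather" where that mode is not implemented
    if sp_mode == "ring" and "ring" not in supported_modes:
        warnings.warn("sequence parallelism mode 'ring' is not supported here, falling back to 'split_gather'")
        sp_mode = "split_gather"
    # GPT-2 and LLaMA also switch FlashAttention off, since the two features cannot
    # currently be combined (the GPT-2 policy mutates the shard config directly)
    if shard_config.enable_flash_attention:
        warnings.warn(f"sequence parallelism mode {sp_mode} cannot be used with FlashAttention, disabling it")
        shard_config.enable_flash_attention = False
    return sp_mode, shard_config.sequence_parallel_size, shard_config.sequence_parallel_process_group
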
+ ) + use_flash_attention = False + + if sp_mode in ["split_gather", "ring"]: + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group + ), + }, + policy=policy, + target_key=LlamaModel, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) + elif sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[LlamaAttention] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward( + sp_mode=sp_mode, + sp_size=sp_size, + sp_group=sp_group, + ), + }, + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { @@ -65,30 +132,37 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) @@ -108,10 +182,12 @@ class LlamaPolicy(Policy): SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -122,16 +198,17 @@ class LlamaPolicy(Policy): description=SubModuleReplacementDescription( suffix="norm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=LlamaModel, ) # use flash attention - if self.shard_config.enable_flash_attention: + if use_flash_attention: self.append_or_create_method_replacement( description={ - "forward": get_llama_flash_attention_forward(self.shard_config), + "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size), }, policy=policy, target_key=LlamaAttention, @@ -243,7 +320,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and not 
self.shard_config.enable_sequence_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index ce78a7e94..7489873c2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,3 +1,4 @@ +import warnings from dataclasses import dataclass, field from typing import Any, Dict, Optional @@ -9,6 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @dataclass @@ -29,13 +31,15 @@ class ShardConfig: enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None + sequence_parallel_process_group: Optional[ProcessGroup] = None pipeline_stage_manager: Optional[PipelineStageManager] = None enable_tensor_parallelism: bool = True + enable_all_optimization: bool = False enable_fused_normalization: bool = False enable_flash_attention: bool = False enable_jit_fused: bool = False - enable_all_optimization: bool = False enable_sequence_parallelism: bool = False + sequence_parallelism_mode: str = None enable_sequence_overlap: bool = False parallel_output: bool = True gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None @@ -50,22 +54,57 @@ class ShardConfig: def tensor_parallel_size(self): return self._tensor_parallel_size + @property + def sequence_parallel_size(self): + return self._sequence_parallel_size + def __post_init__(self): - if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: - raise ValueError( - "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" - ) - if not self.enable_sequence_parallelism and self.enable_sequence_overlap: - raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") - if not self.enable_tensor_parallelism: - self._tensor_parallel_size = 1 - else: - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: self._turn_on_all_optimization() + if self.enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True" + elif self.sequence_parallelism_mode in ["all_to_all"]: + assert ( + not self.enable_tensor_parallelism + ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False" + if self.enable_sequence_overlap: + self.enable_sequence_overlap = False + warnings.warn( + f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode 
{self.sequence_parallelism_mode}" + ) + else: + if self.sequence_parallelism_mode: + self.sequence_parallelism_mode = None + warnings.warn( + f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False" + ) + assert ( + not self.enable_sequence_overlap + ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True" + + # get the tensor parallel size + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # get the sequence parallel size + if not self.enable_sequence_parallelism: + self._sequence_parallel_size = 1 + else: + self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) + def _turn_on_all_optimization(self): """ Turn on all optimization. diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a2433d1b2..bbbaf13b5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -79,6 +79,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -494,7 +495,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: get_accelerator().synchronize() - self.zero_grad() def backward_by_grad(self, tensor, grad): diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 24f9627c2..ab5d97420 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -18,8 +18,23 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt') # input_ids = tokenized_input['input_ids'] # attention_mask = tokenized_input['attention_mask'] - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + # input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) + # attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + input_ids = torch.tensor( + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + dtype=torch.int64, + ) + return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -35,9 +50,9 @@ def data_gen_for_question_answering(): # question answering data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - start_positions = torch.tensor([0], dtype=torch.int64) + start_positions = torch.tensor([[0], [0]], dtype=torch.int64) data["start_positions"] = start_positions - end_positions = torch.tensor([1], dtype=torch.int64) + end_positions = torch.tensor([[1], [1]], dtype=torch.int64) data["end_positions"] = end_positions return data @@ -46,14 +61,20 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = 
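Note: the ShardConfig.__post_init__ rules introduced above boil down to: default the mode to "split_gather", allow "split_gather"/"ring" only on top of tensor parallelism, allow "all_to_all" only without it, and silently drop enable_sequence_overlap where it does not apply. A condensed standalone restatement of those rules (illustrative dataclass, not the real ShardConfig, which also derives the parallel sizes from the process groups):

import warnings
from dataclasses import dataclass

_SUPPORTED_SP_MODES = ("split_gather", "ring", "all_to_all")


@dataclass
class _SPConfigSketch:
    enable_tensor_parallelism: bool = True
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: str = None
    enable_sequence_overlap: bool = False

    def __post_init__(self):
        if self.enable_sequence_parallelism:
            self.sequence_parallelism_mode = self.sequence_parallelism_mode or "split_gather"
            assert self.sequence_parallelism_mode in _SUPPORTED_SP_MODES
            if self.sequence_parallelism_mode in ("split_gather", "ring"):
                # these modes piggyback on the tensor-parallel group, so TP must be on
                assert self.enable_tensor_parallelism
            else:  # "all_to_all" uses its own sequence-parallel group and excludes TP
                assert not self.enable_tensor_parallelism
                if self.enable_sequence_overlap:
                    warnings.warn("enable_sequence_overlap is ignored in all_to_all mode")
                    self.enable_sequence_overlap = False
        else:
            if self.sequence_parallelism_mode is not None:
                warnings.warn("sequence_parallelism_mode is ignored when sequence parallelism is off")
                self.sequence_parallelism_mode = None
            assert not self.enable_sequence_overlap
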
data_gen() - data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor( + [ + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], + ], + dtype=torch.int64, + ) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data["labels"] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64) return data @@ -61,12 +82,18 @@ def date_gen_for_double_heads(): num_choices = 2 batch_size = 2 input_ids = torch.tensor( - [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], + [ + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779], + ], + dtype=torch.int64, + ) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64, ) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) + mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() @@ -103,6 +130,7 @@ config = transformers.GPT2Config( hidden_dropout=0, problem_type="single_label_classification", pad_token_id=50256, + tie_word_embeddings=True, ) config_for_token_classification = copy.deepcopy(config) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 9f801e0cc..58b5b0487 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -28,9 +28,19 @@ if HAS_LLAMA: # ----------------------------------- input_ids = torch.Tensor( - [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + [ + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], + ] ).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() + + attention_mask = torch.Tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + ).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index ece3b4036..ac6f8caef 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() - enable_all_optimization = True if tp_size > 1 else False + + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False with shared_tempdir() as tempdir: pretrained_path = 
os.path.join(tempdir, "pretrained") @@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b plugin = GeminiPlugin( **placement_config, tp_size=tp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, extra_dp_size=extra_dp_size, ) booster = Booster(plugin=plugin) @@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - enable_all_optimization = True if tp_size > 1 else False + enable_flash_attention = True if tp_size > 1 else False + enable_fused_normalization = True if tp_size > 1 else False + enable_jit_fused = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) plugin = GeminiPlugin( **placement_config, @@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, - enable_all_optimization=enable_all_optimization, + enable_flash_attention=enable_flash_attention, + enable_fused_normalization=enable_fused_normalization, + enable_jit_fused=enable_jit_fused, ) booster = Booster(plugin=plugin) diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 08542d1f6..3d206622d 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases(): 2: [2], 3: [3], } + TPxPP_RANKS_IN_GROUP = { + 0: [0, 1, 2, 3], + 1: [0, 1, 2, 3], + 2: [0, 1, 2, 3], + 3: [0, 1, 2, 3], + } + DPxTP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + TPxPP_PARTIAL_INDICES = { + 0: [[0, 1], [0]], + 1: [[1], [0, 1]], + 2: [[0], [0, 1]], + 3: [[0, 1], [1]], + } + TPxPP_RANKS_IN_GROUP_PARTIAL = { + 0: [0, 1], + 1: [1, 3], + 2: [0, 2], + 3: [2, 3], + } pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE) @@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases(): assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank] dp_group = pg_mesh.get_group_along_axis(DP_DIM) assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank] + dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM]) + assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank] + tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM]) + assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank] + tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank]) + assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank] # check prev rank if RANK_TO_COORDINATE[rank][TP_DIM] != 0: diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index e056860ed..e9aa0dbed 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -def 
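Note: the new create_group_along_axis([DP_DIM, TP_DIM]) cases in the process-group-mesh test above group together every rank that differs only along the merged axes. A plain-Python re-derivation of that rank bookkeeping, assuming the 1 x 2 x 2 (DP, PP, TP) mesh that the expected groups imply (this is an independent illustration, not the ProcessGroupMesh API):

DP, PP, TP = 1, 2, 2     # assumed mesh sizes, inferred from the expected groups
coords = {r: (r // (PP * TP), (r // TP) % PP, r % TP) for r in range(DP * PP * TP)}


def ranks_in_group(rank, merged_axes):
    """Ranks whose coordinates match `rank` on every axis that is *not* merged."""
    fixed = [ax for ax in range(3) if ax not in merged_axes]
    return sorted(r for r, c in coords.items() if all(c[ax] == coords[rank][ax] for ax in fixed))


# merging (DP, TP) keeps PP fixed: ranks 0,1 share pp=0 and ranks 2,3 share pp=1
assert ranks_in_group(0, (0, 2)) == [0, 1] and ranks_in_group(3, (0, 2)) == [2, 3]
# merging (TP, PP) keeps only DP fixed, which in this mesh is the whole world
assert ranks_in_group(2, (2, 1)) == [0, 1, 2, 3]
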
check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + linear_copy, + process_group=None, + gather_output=True, + seq_parallel_mode=seq_parallel_mode, + n_fused=3, + overlap=overlap, ) assert linear.weight.shape == torch.Size([48, 192]) @@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) # check computation correctness x = torch.rand(1, 4, 48).cuda() out = linear(x) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) gather_out = linear_conv_col(x_for_shard) assert_close(rearrange(out, -1), gather_out) @@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() linear_row = GPT2FusedLinearConv1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear.weight.shape == torch.Size([48, 192]) @@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness @@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", ["split_gather", None]) @parameterize("overlap", [True]) -def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): - check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) - check_linear_conv_1d_row(lazy_init, seq_parallel) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_conv_1d_row(lazy_init, seq_parallel_mode) def run_dist(rank, world_size, port): diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index defa4afb9..21d3190de 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): +def 
check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module( - linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) # ensure that the parameters are distributed @@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) -def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() linear_row = Linear1D_Row.from_native_module( - linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode ) assert linear_row.weight.shape == torch.Size([128, 16]) @@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] assert_close(target_out, gather_out) # check backward correctness @@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() linear_col = Linear1D_Col.from_native_module( - linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap ) linear_row = Linear1D_Row.from_native_module( - linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode ) linear_1.load_state_dict(linear_col.state_dict()) @@ -141,13 +143,17 @@ def 
check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + x_for_shard = ( + x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + ) x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + target_out = ( + unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_out, shard_out) # check backward correctness @@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool assert x_for_unshard.grad is not None target_unshard_gard = ( x_for_unshard.grad - if seq_parallel is False + if seq_parallel_mode is None else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] ) assert_close(target_unshard_gard, x_for_shard.grad) @parameterize("lazy_init", [False, True]) -@parameterize("seq_parallel", [False, True]) +@parameterize("seq_parallel_mode", [None, "split_gather"]) @parameterize("overlap", [True]) -def run_dist_linear_test(lazy_init, seq_parallel, overlap): - check_linear_1d_col(lazy_init, seq_parallel, overlap) - check_linear_1d_row(lazy_init, seq_parallel) - check_linear_col_plus_row(lazy_init, seq_parallel, overlap) +def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap): + check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) + check_linear_1d_row(lazy_init, seq_parallel_mode) + check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) def check_dist_linear(rank, world_size, port): diff --git a/tests/test_shardformer/test_layer/test_sequence_parallel.py b/tests/test_shardformer/test_layer/test_sequence_parallel.py new file mode 100644 index 000000000..13b1a13e7 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_sequence_parallel.py @@ -0,0 +1,178 @@ +import copy + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import all_to_all_comm +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +class SequenceParallelAttention(torch.nn.Module): + """Initialization. 
+ + Arguments: + local_attention (Module): local attention with q,k,v + sequence_process_group (ProcessGroup): sequence parallel process group + scatter_idx (int): scatter_idx for all2all comm + gather_idx (int): gather_idx for all2all comm + """ + + def __init__( + self, + heads_num: torch.Tensor, + hidden_dim: torch.Tensor, + enable_sequence_parallellism: bool = False, + sequence_process_group: dist.ProcessGroup = None, + scatter_idx: int = 2, + gather_idx: int = 1, + ) -> None: + super(SequenceParallelAttention, self).__init__() + self.spg = sequence_process_group + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + self.heads_num = heads_num + self.hidden_dim = hidden_dim + assert hidden_dim % heads_num == 0 + self.head_dim = hidden_dim // heads_num + self.enable_sequence_parallellism = enable_sequence_parallellism + + self.q = nn.Linear(hidden_dim, hidden_dim) + self.k = nn.Linear(hidden_dim, hidden_dim) + self.v = nn.Linear(hidden_dim, hidden_dim) + self.out = nn.Linear(hidden_dim, hidden_dim) + + def attn(self, q, k, v): + batch_size, seq_len = q.shape[0], q.shape[1] + + scale = self.head_dim**0.5 + qk = torch.matmul(q, k.transpose(-2, -1)) / scale + weights = F.softmax(qk, dim=-1) + + attention_score = torch.matmul(weights, v) + + return attention_score + + def forward(self, x) -> Tensor: + bsz, q_len, _ = x.size() + + seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len + num_heads = ( + self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num + ) + + # in shape : e.g., [s/p:h:] + query_states = self.q(x) + key_states = self.k(x) + value_states = self.v(x) + + if self.enable_sequence_parallellism: + query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx) + key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx) + value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx) + + query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2) + # out shape : e.g., [s:h/p:] + attn_score = self.attn(query_states, key_states, value_states) + attn_score = attn_score.transpose(1, 2).contiguous() + attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim) + if self.enable_sequence_parallellism: + attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx) + + # output e.g., [s/p::h] + output = self.out(attn_score) + + return output + + +def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_len = seq_len + hidden_dim = hidden_dim + head_num = head_num + batch_size = batch_size + world_size = dist.get_world_size() + + x = torch.randn(batch_size, seq_len, hidden_dim).cuda() + x_unshard = x.clone() + x_unshard.requires_grad_(True) + x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()] + x_input.requires_grad_(True) + + # Multi-head Attention + mha = SequenceParallelAttention(head_num, hidden_dim).cuda() + # Multi-head Attention forward + mha_out = mha(x_unshard) + + # Sequence parallel Attention + sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda() + sp_attn.load_state_dict(copy.deepcopy(mha.state_dict())) + # Sequence parallel Attention forward + dist_attn_out = sp_attn(x_input) + + # gather the output of 
sequence parallel attention + out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)] + dist.all_gather(out_list, dist_attn_out) + seq_out = torch.cat(out_list, dim=1) + + # forward result check + assert_close(seq_out, mha_out) + + # Multi-head Attention backward + mha_out.sum().backward() + q_grad = mha.q.weight.grad + k_grad = mha.k.weight.grad + v_grad = mha.v.weight.grad + o_grad = mha.out.weight.grad + x_grad = x_unshard.grad + + # Sequence parallel Attention backward + dist_attn_out.sum().backward() + q_grad_seq = sp_attn.q.weight.grad + k_grad_seq = sp_attn.k.weight.grad + v_grad_seq = sp_attn.v.weight.grad + o_grad_seq = sp_attn.out.weight.grad + x_grad_seq = x_input.grad + # all_reduce the grad of sequence parallel attention weight + dist.all_reduce(q_grad_seq) + dist.all_reduce(k_grad_seq) + dist.all_reduce(v_grad_seq) + dist.all_reduce(o_grad_seq) + # gather the grad of sequence parallel attention input + x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)] + dist.all_gather(x_grad_seq_list, x_grad_seq) + x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1) + + # backward result check + assert_close(q_grad_seq, q_grad) + assert_close(k_grad_seq, k_grad) + assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4) + assert_close(o_grad_seq, o_grad) + assert_close(x_grad_seq_gather, x_grad) + + +@parameterize("seq_len", [128]) +@parameterize("hidden_dim", [64]) +@parameterize("head_num", [4]) +@parameterize("batch_size", [1]) +def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size): + seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size) + + +def check_all2all_attn(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_seq_parallel_attn() + + +@rerun_if_address_is_in_use() +def test_all_to_all_attention(): + spawn(check_all2all_attn, nprocs=4) + + +if __name__ == "__main__": + test_all_to_all_attention() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 85be9a242..d5fc2c30f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,5 +1,4 @@ import copy -import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c sharded_model = copy.deepcopy(org_model) if use_lazy_init: ctx.materialize(org_model) - org_model = org_model.cuda() org_optimizer = Adam(org_model.parameters(), lr=1e-3) sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) @@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin( data = data_gen_fn() - if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data["input_ids"].shape[-1] - lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) - times = lcm // seq_len - input_shape = data["input_ids"].shape - for k, v in data.items(): - if v.shape == input_shape: - data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) + shard_test_data = {} + for k, v in data.items(): + shard_test_data[k] = data[k].clone() + unshard_test_data = {} + for k, v in data.items(): + unshard_test_data[k] = data[k].clone() sharded_model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): + for k, v in shard_test_data.items(): if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: 
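# --- Editor's aside on the SequenceParallelAttention test added above: a minimal,
# single-process sketch of the shape bookkeeping behind its all_to_all_comm calls.
# The fake_all_to_all helper below is hypothetical and only mimics the resulting
# shapes; the real collective also exchanges data across ranks. With scatter_idx=2
# and gather_idx=1, each rank's [batch, seq/P, hidden] shard becomes
# [batch, seq, hidden/P], so attention runs over the full sequence with
# heads_num/P heads, and the reverse call restores the sequence shard before the
# output projection.
import torch


def fake_all_to_all(x: torch.Tensor, world_size: int, scatter_idx: int, gather_idx: int) -> torch.Tensor:
    # Split the scatter dimension into world_size pieces and stack the pieces
    # along the gather dimension -- shape-wise, this is what the collective yields.
    chunks = torch.chunk(x, world_size, dim=scatter_idx)
    return torch.cat(chunks, dim=gather_idx)


P, bsz, seq, hidden = 4, 2, 128, 64
x_shard = torch.randn(bsz, seq // P, hidden)                            # per-rank shard: [2, 32, 64]
q_full_seq = fake_all_to_all(x_shard, P, scatter_idx=2, gather_idx=1)   # -> [2, 128, 16]
restored = fake_all_to_all(q_full_seq, P, scatter_idx=1, gather_idx=2)  # -> [2, 32, 64]
assert q_full_seq.shape == (bsz, seq, hidden // P)
assert restored.shape == x_shard.shape
# --- end of aside ---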
new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to("cuda").repeat(*new_shape) + shard_test_data[k] = v.to("cuda").repeat(*new_shape) - data_iter = iter([data]) + data_iter = iter([shard_test_data]) sharded_output = booster.execute_pipeline( data_iter, sharded_model, @@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin( return_outputs=True, ) sharded_loss = sharded_output["loss"] - else: - data = {k: v.cuda() for k, v in data.items()} - sharded_output = sharded_model(**data) + else: + shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()} + sharded_output = sharded_model(**shard_test_data) sharded_loss = criterion(sharded_output) sharded_optimizer.backward(sharded_loss) org_model.train() - data = {k: v.cuda() for k, v in data.items()} - org_output = org_model(**data) - + if booster.plugin.stage_manager is not None: + for k, v in unshard_test_data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + unshard_test_data[k] = v.to("cuda").repeat(*new_shape) + unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()} + org_output = org_model(**unshard_test_data) org_loss = criterion(org_output) org_loss.backward() @@ -212,7 +213,6 @@ def check_output_hidden_state( stage_manager: Optional[PipelineStageManager] = None, atol: float = 1e-5, rtol: float = 1e-3, - dim: int = 0, ): org_hidden_state = org_output.last_hidden_state diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 768bd95bd..919557797 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 1, @@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ) def run_bert_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index b70cba8b4..cc0786618 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + 
"sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 78d752b69..405ceba32 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index d59d7e4ad..4aac7f3d4 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 55858cbd4..27f904292 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,6 +2,8 @@ import os import pytest import torch +import torch.distributed as dist +from torch.testing import assert_close import colossalai from colossalai.logging import disable_existing_loggers @@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] col_layer_for_check = ["layers[0].self_attn.o_proj"] + # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism + norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"] + + # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enbaled + if stage_manager is None: + norm_layer_for_check.append("norm") + + # Check the grad when using ZeRO-1 and ZeRO-2 + if ( + booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.enable_sequence_parallelism + and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" + ): + for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]): + 
working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)] + grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p)) + grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank + grad = grads[grad_index] + sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()] + assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False) # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} @@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, col_layer_grads = get_grad_tensors_for_check( llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) + norm_layer_grads = get_grad_tensors_for_check( + llama_model, + shard_llama_model, + norm_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) + grads_to_check.update(norm_layer_grads) # optimizer executes step org_optimizer.step() @@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { + "tp_size": 2, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, From 15055f9a36d215e5a4fba8658e1b04c895881da6 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Sun, 7 Apr 2024 12:06:27 +0800 Subject: [PATCH 40/43] [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz --- colossalai/_analyzer/fx/tracer/tracer.py | 2 +- colossalai/nn/layer/scaled_softmax.py | 8 ++++++++ .../tutorial/auto_parallel/auto_parallel_with_resnet.py | 2 +- examples/tutorial/hybrid_parallel/train.py | 4 ++-- examples/tutorial/large_batch_optimizer/train.py | 6 +++--- examples/tutorial/opt/opt/run_clm.py | 4 ++-- .../tutorial/sequence_parallel/model/layers/bert_layer.py | 2 +- examples/tutorial/sequence_parallel/train.py | 4 ++-- 
8 files changed, 20 insertions(+), 12 deletions(-) mode change 100755 => 100644 examples/tutorial/opt/opt/run_clm.py diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 17dce7672..36e8780af 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -237,7 +237,7 @@ class ColoTracer(Tracer): # override the tracer to support custom modules and checkpointing if self.trace_act_ckpt: orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply - orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant + orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant_generator def checkpoint(run_function, preserve_rng_state=False, *args): self.ckpt_regions.append(self.ckpt_idx) diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py index a8d72ddd9..2e802db2d 100644 --- a/colossalai/nn/layer/scaled_softmax.py +++ b/colossalai/nn/layer/scaled_softmax.py @@ -8,6 +8,14 @@ import torch.nn as nn from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader +# NOTE: These kernels are compiled on specific GPU arch and not widely applicable. +# try: +# from colossalai._C import scaled_masked_softmax as scaled_masked_softmax, scaled_upper_triangle_masked_softmax_cuda as scaled_upper_triang_masked_softmax +# except ImportError: + +scaled_masked_softmax = None +scaled_upper_triang_masked_softmax = None + class AttnMaskType(enum.Enum): padding = 1 diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index 3c5b786b5..42b43bbc0 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -17,7 +17,7 @@ def synthesize_data(): def main(): - colossalai.launch_from_torch(config="./config.py") + colossalai.legacy.launch_from_torch(config="./config.py") logger = get_dist_logger() diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py index 95f1bf8ee..70445dff9 100644 --- a/examples/tutorial/hybrid_parallel/train.py +++ b/examples/tutorial/hybrid_parallel/train.py @@ -41,9 +41,9 @@ class DummyDataloader: def main(): # launch from torch - parser = colossalai.get_default_parser() + parser = colossalai.legacy.get_default_parser() args = parser.parse_args() - colossalai.launch_from_torch(config=args.config) + colossalai.legacy.launch_from_torch(config=args.config) # get logger logger = get_dist_logger() diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py index dd114b5af..c78292e0c 100644 --- a/examples/tutorial/large_batch_optimizer/train.py +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -37,14 +37,14 @@ class DummyDataloader: def main(): # initialize distributed setting - parser = colossalai.get_default_parser() + parser = colossalai.legacy.get_default_parser() parser.add_argument( "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True ) args = parser.parse_args() # launch from torch - colossalai.launch_from_torch(config=args.config) + colossalai.legacy.launch_from_torch(config=args.config) # get logger logger = get_dist_logger() @@ -73,7 +73,7 @@ def main(): ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + engine, 
train_dataloader, test_dataloader, _ = colossalai.legacy.initialize( model=model, optimizer=optimizer, criterion=criterion, diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py old mode 100755 new mode 100644 index 3f0d04879..ae8a0f4a0 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -72,7 +72,7 @@ def get_time_stamp(): def parse_args(): - parser = colossalai.get_default_parser() + parser = colossalai.legacy.get_default_parser() parser.add_argument("-s", "--synthetic", action="store_true") parser.add_argument( "--dataset_name", @@ -289,7 +289,7 @@ class DummyDataloader: def main(): args = parse_args() disable_existing_loggers() - colossalai.launch_from_torch(config=dict()) + colossalai.legacy.launch_from_torch(config=dict()) logger = get_dist_logger() is_main_process = dist.get_rank() == 0 diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py index 1ef16ee6a..c7782bf04 100644 --- a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -from colossalai.kernel.cuda_native import LayerNorm from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train from colossalai.legacy.nn.layer.parallel_sequence import TransformerSelfAttentionRing +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .dropout import get_bias_dropout_add from .mlp import TransformerMLP diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index f25fc8189..431af3586 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -48,7 +48,7 @@ def pipeline_data_process_func(stage_output, micro_batch_data): def main(): # initialize parse_args() - colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl") + colossalai.legacy.launch_from_torch(config="./config.py", seed=1234, backend="nccl") logger = get_dist_logger() @@ -136,7 +136,7 @@ def main(): logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") # # init - engine, *dummy = colossalai.initialize(model, optimizer, criterion, verbose=True) + engine, *dummy = colossalai.legacy.initialize(model, optimizer, criterion, verbose=True) # build timer timer = MultiTimer() From a799ca343b13665661a5e95f5ad1523457bef2e2 Mon Sep 17 00:00:00 2001 From: digger yu Date: Sun, 7 Apr 2024 18:42:15 +0800 Subject: [PATCH 41/43] [fix] fix typo s/muiti-node /multi-node etc. 
(#5448) --- .../ColossalEval/colossal_eval/evaluate/gpt_evaluate.py | 2 +- applications/ColossalMoE/train.py | 6 +++--- .../ColossalQA/colossalqa/data_loader/document_loader.py | 2 +- .../ColossalQA/colossalqa/local/colossalcloud_llm.py | 2 +- applications/ColossalQA/colossalqa/local/llm.py | 6 +++--- .../colossalqa/retrieval_conversation_universal.py | 2 +- applications/ColossalQA/colossalqa/retriever.py | 2 +- applications/ColossalQA/colossalqa/utils.py | 2 +- examples/language/openmoe/model/modeling_openmoe.py | 2 +- examples/language/openmoe/train.py | 6 +++--- 10 files changed, 16 insertions(+), 16 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py index a0b1ed114..19907daaf 100644 --- a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py +++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py @@ -670,7 +670,7 @@ def calculate_scores_form_logprobs(logprobs: Dict[str, Any]) -> float: def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> int: """ Calculate the score from the response returned by gpt-3.5-turbo or gpt-4. - Different from text-davinci-003, this fuction directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4. + Different from text-davinci-003, this function directly calculates the score according to the plain response returned by gpt-3.5-turbo or gpt-4. Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo. Args: diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py index 99603282b..850236726 100644 --- a/applications/ColossalMoE/train.py +++ b/applications/ColossalMoE/train.py @@ -128,13 +128,13 @@ def parse_args(): parser.add_argument( "--comm_overlap", action="store_true", - help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + help="Use communication overlap for MoE. Recommended to enable for multi-node training.", ) # hierarchical all-to-all parser.add_argument( "--hierarchical_alltoall", action="store_true", - help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.", + help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.", ) args = parser.parse_args() @@ -267,7 +267,7 @@ def main(): # ): # coordinator.print_on_master(f"Apply load balance") # apply_load_balance(model, optimizer) - # save ckeckpoint + # save checkpoint if (step + 1) % args.save_interval == 0: coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") save_checkpoint( diff --git a/applications/ColossalQA/colossalqa/data_loader/document_loader.py b/applications/ColossalQA/colossalqa/data_loader/document_loader.py index 4ddbf2b9d..ca0030621 100644 --- a/applications/ColossalQA/colossalqa/data_loader/document_loader.py +++ b/applications/ColossalQA/colossalqa/data_loader/document_loader.py @@ -52,7 +52,7 @@ class DocumentLoader: def load_data(self, path: str) -> None: """ Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/ - for sepcific format requirements. + for specific format requirements. Args: path: path to a file To load files with glob path, here are some examples. 
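# --- Editor's aside on the glob-path loading referenced in the docstring above:
# a generic, stand-alone illustration using only the standard-library glob module.
# The directory names are made up for illustration; this is not the loader's own
# (elided) example list.
import glob

# Every markdown file directly under ./docs (hypothetical directory).
markdown_files = glob.glob("./docs/*.md")
# Every .txt file at any depth under ./data ("**" needs recursive=True).
text_files = glob.glob("./data/**/*.txt", recursive=True)
print(markdown_files, text_files)
# --- end of aside ---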
diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py index 62aead66c..483f65e20 100644 --- a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py +++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py @@ -100,7 +100,7 @@ class ColossalCloudLLM(LLM): def text_completion(self, prompt, gen_config, auth_config): - # Complusory Parameters + # Required Parameters endpoint = auth_config.pop('endpoint') max_new_tokens = gen_config.pop('max_new_tokens') # Optional Parameters diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index 0aa383e9d..4cc82c1a3 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -33,7 +33,7 @@ class ColossalAPI: def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None: """ - Configurate model + Configure model """ if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances: return @@ -47,7 +47,7 @@ class ColossalAPI: self.model.load_state_dict(state_dict) self.model.to(torch.cuda.current_device()) - # Configurate tokenizer + # Configure tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) self.model.eval() @@ -87,7 +87,7 @@ class ColossalAPI: class VllmAPI: def __init__(self, host: str = "localhost", port: int = 8077) -> None: - # Configurate api for model served through web + # Configure api for model served through web self.host = host self.port = port self.url = f"http://{self.host}:{self.port}/generate" diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py index b23058d6d..6e77bb2ae 100644 --- a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py +++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py @@ -36,7 +36,7 @@ class UniversalRetrievalConversation: text_splitter_chunk_overlap=10, ) -> None: """ - Warpper for multilingual retrieval qa class (Chinese + English) + Wrapper for multilingual retrieval qa class (Chinese + English) Args: embedding_model_path: local or huggingface embedding model embedding_model_device: diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py index c891cb613..870a73ca5 100644 --- a/applications/ColossalQA/colossalqa/retriever.py +++ b/applications/ColossalQA/colossalqa/retriever.py @@ -59,7 +59,7 @@ class CustomRetriever(BaseRetriever): Add documents to retriever Args: docs: the documents to add - cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destory and rebuild retriever) + cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destroy and rebuild retriever) mode: choose from "by source" (documents are grouped by source) and "merge" (documents are merged into one vector store) """ if cleanup == "full": diff --git a/applications/ColossalQA/colossalqa/utils.py b/applications/ColossalQA/colossalqa/utils.py index cd8c3e5ac..49d99014b 100644 --- a/applications/ColossalQA/colossalqa/utils.py +++ b/applications/ColossalQA/colossalqa/utils.py @@ -49,7 +49,7 @@ def destroy_sql_database(sql_engine: Union[Engine, str]) -> None: def detect_lang_naive(s): """ - Naive function for language detection, should be replaced by an independant layer + Naive function for 
language detection, should be replaced by an independent layer """ remove_nota = "[’·°–!\"#$%&'()*+,-./:;<=>?@,。?★、…【】()《》?“”‘’![\\]^_`{|}~]+" s = re.sub(remove_nota, "", s) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index eee3b505a..fdd8442f5 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -96,7 +96,7 @@ def set_openmoe_args( load_balance_beam_width (int, optional): Expert load balance search's beam width. Defaults to 8. load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Longer value encourages less swap. Defaults to 0.4. enable_kernel (bool, optional): Use kernel optimization. Defaults to False. - enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for muiti-node training. Defaults to False. + enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for multi-node training. Defaults to False. enable_hierarchical_alltoall (bool, optional): Use hierarchical alltoall for MoE. Defaults to False. """ moe_args = dict( diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index f3267b7c6..92f4e066a 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -190,13 +190,13 @@ def parse_args(): parser.add_argument( "--comm_overlap", action="store_true", - help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + help="Use communication overlap for MoE. Recommended to enable for multi-node training.", ) # hierarchical all-to-all parser.add_argument( "--hierarchical_alltoall", action="store_true", - help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.", + help="Use hierarchical all-to-all for MoE. 
Recommended to enable for multi-node training.", ) args = parser.parse_args() @@ -366,7 +366,7 @@ def main(): ): coordinator.print_on_master(f"Apply load balance") apply_load_balance(model, optimizer) - # save ckeckpoint + # save checkpoint if (step + 1) % args.save_interval == 0: coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") booster.save_model(model, args.output_path, shard=True) From 341263df48bbef1174c41b6c4f5f6785f895b0d4 Mon Sep 17 00:00:00 2001 From: digger yu Date: Sun, 7 Apr 2024 19:04:58 +0800 Subject: [PATCH 42/43] [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) --- examples/language/grok-1/inference.py | 4 ++-- examples/language/grok-1/inference_tp.py | 4 ++-- examples/language/grok-1/utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py index faef7ae9d..58ba3872f 100644 --- a/examples/language/grok-1/inference.py +++ b/examples/language/grok-1/inference.py @@ -2,10 +2,10 @@ import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import get_defualt_parser, inference, print_output +from utils import get_default_parser, inference, print_output if __name__ == "__main__": - parser = get_defualt_parser() + parser = get_default_parser() args = parser.parse_args() start = time.time() torch.set_default_dtype(torch.bfloat16) diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py index cf05880dc..e10c4929c 100644 --- a/examples/language/grok-1/inference_tp.py +++ b/examples/language/grok-1/inference_tp.py @@ -3,7 +3,7 @@ import time import torch from grok1_policy import Grok1ForCausalLMPolicy from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import get_defualt_parser, inference, print_output +from utils import get_default_parser, inference, print_output import colossalai from colossalai.booster import Booster @@ -13,7 +13,7 @@ from colossalai.lazy import LazyInitContext from colossalai.utils import get_current_device if __name__ == "__main__": - parser = get_defualt_parser() + parser = get_default_parser() args = parser.parse_args() start = time.time() colossalai.launch_from_torch({}) diff --git a/examples/language/grok-1/utils.py b/examples/language/grok-1/utils.py index 7663127a5..29c86e411 100644 --- a/examples/language/grok-1/utils.py +++ b/examples/language/grok-1/utils.py @@ -33,7 +33,7 @@ def inference(model, tokenizer, text, **generate_kwargs): return outputs[0].tolist() -def get_defualt_parser(): +def get_default_parser(): parser = argparse.ArgumentParser() parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1") parser.add_argument("--tokenizer", type=str, default="tokenizer.model") From 641b1ee71a19e2337f3363620b228dd355835b04 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 8 Apr 2024 15:09:40 +0800 Subject: [PATCH 43/43] [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/pull_request_template.md | 1 + .github/workflows/post_commit.yml | 97 --------- .gitignore | 2 +- LICENSE | 2 +- .../colossal_llama2/model/init_model.py | 3 +- .../colossal_llama2/utils/ckpt_io.py | 2 +- .../utils/stream_chat_patch.py | 119 ++++++----- 
.../Colossal-LLaMA-2/docs/example_13b.md | 6 +- .../Colossal-LLaMA-2/docs/example_7b.md | 2 +- .../Colossal-LLaMA-2/hostfile.example | 2 +- .../Colossal-LLaMA-2/inference_example.py | 2 +- .../Colossal-LLaMA-2/requirements.txt | 1 - .../Colossal-LLaMA-2/stream_chat_example.py | 45 ++-- applications/Colossal-LLaMA-2/version.txt | 2 +- .../examples/training_scripts/train_dpo.py | 4 +- .../examples/dataset_evaluation/inference.py | 1 - applications/ColossalMoE/README.md | Bin 6474 -> 6475 bytes applications/ColossalMoE/infer.py | 1 - .../colossalqa/chain/retrieval_qa/base.py | 18 +- .../colossalqa/data_loader/document_loader.py | 2 +- .../data_loader/table_dataloader.py | 75 ++++--- .../colossalqa/local/colossalcloud_llm.py | 66 +++--- .../ColossalQA/colossalqa/local/llm.py | 1 - .../ColossalQA/colossalqa/prompt/prompt.py | 1 - .../ColossalQA/colossalqa/retriever.py | 8 +- .../ColossalQA/data/data_sample/companies.txt | 6 +- .../data/data_sample/companies_zh.txt | 2 +- .../data/data_sample/csv_organization_100.csv | 2 +- applications/ColossalQA/data/tests/64KB.json | 2 +- .../ColossalQA/data/tests/companies.csv | 202 +++++++++--------- applications/ColossalQA/data/tests/test.html | 48 ++--- applications/ColossalQA/data/tests/test.md | 4 +- applications/ColossalQA/data/tests/test.txt | 76 +++---- .../retrieval_conversation_universal.py | 33 +-- .../examples/webui_demo/RAG_ChatBot.py | 12 +- .../ColossalQA/examples/webui_demo/README.md | 2 +- .../ColossalQA/examples/webui_demo/config.py | 52 ++--- .../ColossalQA/examples/webui_demo/server.py | 29 +-- .../ColossalQA/examples/webui_demo/webui.py | 31 ++- applications/ColossalQA/pytest.ini | 2 +- .../ColossalQA/tests/test_document_loader.py | 14 +- .../ColossalQA/tests/test_retrieval_qa.py | 80 +++---- applications/ColossalQA/version.txt | 2 +- colossalai/__init__.py | 2 +- .../booster/plugin/torch_fsdp_plugin.py | 9 +- colossalai/checkpoint_io/utils.py | 3 +- colossalai/cluster/process_group_mesh.py | 25 ++- colossalai/inference/engine/modeling/llama.py | 49 +++-- .../quant/gptq/cai_gptq/cai_quant_linear.py | 202 +++++++++--------- colossalai/kernel/triton/context_attention.py | 133 ++++++++---- colossalai/kernel/triton/flash_decoding.py | 51 +++-- .../kernel/triton/llama_act_combine_kernel.py | 65 +++--- .../kernel/triton/token_attention_kernel.py | 20 +- .../tensor_parallel/modeling/llama.py | 4 +- colossalai/shardformer/layer/__init__.py | 2 +- colossalai/tensor/d_tensor/__init__.py | 4 +- colossalai/tensor/d_tensor/api.py | 7 +- colossalai/zero/low_level/_utils.py | 1 + .../train_vit_with_hybrid_parallelism.md | 2 +- examples/images/vit/vit_benchmark.py | 4 +- examples/language/data_utils.py | 2 +- examples/language/llama2/finetune.py | 4 +- examples/language/llama2/pretrain.py | 4 +- examples/language/openmoe/benchmark/utils.py | 6 +- examples/language/openmoe/infer.py | 14 +- .../openmoe/model/convert_openmoe_ckpt.py | 22 +- examples/language/opt/opt_train_demo.py | 4 +- extensions/cpu_adam/__init__.py | 3 +- extensions/layernorm/__init__.py | 2 +- extensions/moe/__init__.py | 2 +- extensions/optimizer/__init__.py | 2 +- extensions/softmax/__init__.py | 2 +- tests/kit/model_zoo/__init__.py | 44 ++-- tests/kit/model_zoo/registry.py | 2 +- tests/kit/model_zoo/transformers/chatglm2.py | 1 + ...st_hybrid_parallel_plugin_checkpoint_io.py | 8 +- .../test_torch_fsdp_checkpoint_io.py | 1 + tests/test_gptq/test_gptq_linear.py | 30 ++- tests/test_lazy/test_models.py | 2 +- tests/test_optimizer/test_nvme.py | 3 +- .../test_schedule/test_interleaved.py 
| 4 +- .../test_schedule/test_oneF_oneB.py | 4 +- 82 files changed, 849 insertions(+), 962 deletions(-) delete mode 100644 .github/workflows/post_commit.yml diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 0dc1262b0..d67cd9fcc 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -3,6 +3,7 @@ - [ ] I have created an issue for this PR for traceability - [ ] The title follows the standard format: `[doc/gemini/tensor/...]: A concise description` - [ ] I have added relevant tags if possible for us to better distinguish different PRs +- [ ] I have installed pre-commit: `pip install pre-commit && pre-commit install` ## 🚨 Issue number diff --git a/.github/workflows/post_commit.yml b/.github/workflows/post_commit.yml deleted file mode 100644 index 1bbc0d2f5..000000000 --- a/.github/workflows/post_commit.yml +++ /dev/null @@ -1,97 +0,0 @@ -name: post-commit - -on: - pull_request: - types: - - closed - -jobs: - # this job will run after a PR is merged to run pre-commit on any changed file - # so that the user does not need to learn pre-commit and pre-commit can still - # be auto-executed by the workflow - pre-commit: - runs-on: ubuntu-latest - if: github.event.pull_request.merged == true && github.repository == 'hpcaitech/ColossalAI' - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - ref: ${{ github.event.pull_request.head.sha }} - - # the PR branch and the hpcaitech/colossal-ai main branch - # must share a common commit, we need to locate that commit, - # which is the commit checked-out or forked when the PR branch is created - # such that we can look for files changed since that commit - - name: Locate base commit - id: locate-base-sha - run: | - curBranch=$(git rev-parse --abbrev-ref HEAD) - commonCommit=$(git merge-base origin/main $curBranch) - echo $commonCommit - echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT - - - name: Find the changed files - id: find-changed-files - uses: tj-actions/changed-files@v35 - with: - base_sha: ${{ steps.locate-base-sha.outputs.baseSHA }} - - - name: List all changed files - run: | - for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do - echo "$file was changed" - done - - # check out the main branch - - uses: actions/checkout@v2 - with: - ref: 'main' - - - uses: actions/setup-python@v3 - - - name: Cache pre-commit hooks - uses: actions/cache@v3 - with: - path: ~/.cache/pre-commit - key: ${{ runner.os }}-pre-commit-hooks - - - name: Set up pre-commit - run: | - pip install pre-commit - pre-commit install - - # run pre-commit on changed files - - name: Run Pre-commit - run: | - for file in ${{ steps.find-changed-files.outputs.all_changed_files }}; do - pre-commit run --files $file || true - done - - # create commit for pre-commit - # when all files are well formatted, there is no need to create a commit - # therefore, this step will produce an error, which should be allowed - - name: Create commits - id: commit - continue-on-error: true - run: | - git config --global user.name 'github-actions' - git config --global user.email 'github-actions@github.com' - git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} - git add -A - git commit -am "[format] applied code formatting on changed files in pull request ${{ github.event.pull_request.number }}" - - # create pull request - - name: Create Pull Request - if: steps.commit.outcome == 'success' - id: cpr - uses: peter-evans/create-pull-request@v4 - with: - 
branch: pre-commit-${{ github.event.pull_request.number }} - title: "[format] applied code formatting on changed files in PR ${{ github.event.pull_request.number }}" - - - name: Enable Auto-merge for the New PR - if: steps.commit.outcome == 'success' - uses: peter-evans/enable-pull-request-automerge@v2 - with: - pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} - merge-method: squash diff --git a/.gitignore b/.gitignore index e539a2612..8bc74b4c8 100644 --- a/.gitignore +++ b/.gitignore @@ -162,4 +162,4 @@ coverage.xml # log, test files - ColossalChat applications/ColossalChat/logs -applications/ColossalChat/tests/logs \ No newline at end of file +applications/ColossalChat/tests/logs diff --git a/LICENSE b/LICENSE index bacb03e72..47197afe6 100644 --- a/LICENSE +++ b/LICENSE @@ -551,4 +551,4 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. + THE SOFTWARE. diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py index 67e487f43..f61291f35 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py @@ -8,11 +8,10 @@ import argparse import numpy as np import torch -from transformers import LlamaTokenizer, LlamaForCausalLM +from transformers import LlamaForCausalLM, LlamaTokenizer from colossalai.logging import get_dist_logger - logger = get_dist_logger() diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py index 85decf37d..05342ce41 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py @@ -10,8 +10,8 @@ import os from typing import Any, Dict, Tuple, Union import torch -from torch.optim.optimizer import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer from colossalai.booster import Booster from colossalai.cluster import DistCoordinator diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py index 8f8eecb18..44fa3678d 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/stream_chat_patch.py @@ -1,20 +1,19 @@ from copy import deepcopy -from typing import Optional, List, Dict, Tuple, Callable, Any +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import nn - from transformers import PreTrainedTokenizer -from transformers.utils import logging from transformers.generation.utils import GenerationConfig, LogitsProcessorList, StoppingCriteriaList - +from transformers.utils import logging + logger = logging.get_logger(__name__) def get_prompt_template( - input_query:str, - history:List[Dict]= None, - roles:list = ["", "Human", "Assistant"], + input_query: str, + history: List[Dict] = None, + roles: list = ["", "Human", "Assistant"], ) -> str: """ Generates a prompt template for chat models based on input and history. 
@@ -32,7 +31,7 @@ def get_prompt_template( new_history = [] else: new_history = deepcopy(history) - + new_history.append({"role": roles[1], "message": input_query.strip()}) new_history.append({"role": roles[2], "message": None}) @@ -48,22 +47,23 @@ def get_prompt_template( prompt += f"{role}: " return prompt + @torch.inference_mode() def streaming_chat( - model: Any, + model: Any, tokenizer: PreTrainedTokenizer, - input_query: str, - history: List[Dict] = None, - roles: list = ["", "Human", "Assistant"], - past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None, - temperature: float = 0.8, - top_p: float = 0.95, - top_k: int = 50, - do_sample: bool = True, + input_query: str, + history: List[Dict] = None, + roles: list = ["", "Human", "Assistant"], + past_key_values: Tuple[Tuple[torch.FloatTensor, Any], Any] = None, + temperature: float = 0.8, + top_p: float = 0.95, + top_k: int = 50, + do_sample: bool = True, length_penalty: float = 1.2, - max_new_tokens: int = 512, - logits_processor: LogitsProcessorList = None, - return_past_key_values: bool = False, + max_new_tokens: int = 512, + logits_processor: LogitsProcessorList = None, + return_past_key_values: bool = False, **kwargs, ): """ @@ -87,7 +87,7 @@ def streaming_chat( **kwargs: Additional keyword arguments for generation. Yields: - Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and + Tuple[str, List[Dict], Optional[Tuple[Tuple[torch.FloatTensor, Any], Any]]]: A tuple containing the generated response, updated history, and optionally the updated past key values if `return_past_key_values` is True. Ensures padding is on the left side for the tokenizer. @@ -97,32 +97,37 @@ def streaming_chat( history = [] if logits_processor is None: logits_processor = LogitsProcessorList() - + generation_kwargs = { - 'temperature': temperature, - 'top_p': top_p, - 'top_k': top_k, - 'do_sample': do_sample, - 'max_new_tokens': max_new_tokens, - 'length_penalty': length_penalty, - 'use_cache': True, - **kwargs + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "do_sample": do_sample, + "max_new_tokens": max_new_tokens, + "length_penalty": length_penalty, + "use_cache": True, + **kwargs, } prompt_str = get_prompt_template(input_query, history=history, roles=roles) - + eos_token_id = [tokenizer.eos_token_id] inputs = tokenizer(prompt_str, return_tensors="pt").to(model.device) history.append({"role": roles[1], "message": input_query.strip()}) history.append({"role": roles[2], "message": None}) - for outputs in stream_generate(model, **inputs, past_key_values=past_key_values, - eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, - **generation_kwargs): + for outputs in stream_generate( + model, + **inputs, + past_key_values=past_key_values, + eos_token_id=eos_token_id, + return_past_key_values=return_past_key_values, + **generation_kwargs, + ): if return_past_key_values: outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1] response = tokenizer.decode(outputs) history[-1]["message"] = response.strip() @@ -130,30 +135,30 @@ def streaming_chat( yield response, history, past_key_values else: yield response, history - + @torch.inference_mode() def stream_generate( - model: Any, - input_ids: torch.Tensor, + model: Any, + input_ids: torch.Tensor, generation_config: Optional[GenerationConfig] = None, logits_processor: 
Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values: bool = False, + return_past_key_values: bool = False, **kwargs, ): """ Generates sequences of token ids using the specified model and generation parameters. Adapted from https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py - + Args: model (Any): The model used for generating sequences of token ids. - input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder. + input_ids (torch.Tensor): The sequence used as a prompt for the generation or as model inputs to the encoder. generation_config (Optional[GenerationConfig]): The generation configuration to be used as base parametrization for the generation call. logits_processor (Optional[LogitsProcessorList]): Custom logits processors that complement the default logits processors built from arguments and generation config. - stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments + stopping_criteria (Optional[StoppingCriteriaList]): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to constrain token generation. return_past_key_values (bool): Whether to return past key values for further incremental decoding, defaults to False. @@ -169,7 +174,7 @@ def stream_generate( generation_config = model.generation_config generation_config = deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) - + eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -177,25 +182,25 @@ def stream_generate( if generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + input_ids_len - + if input_ids_len >= generation_config.max_length: input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids" logger.warning( - f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) + f"Input length of {input_ids_string} is {input_ids_len}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." 
+ ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - + # prepare distribution pre_processing samplers logits_processor = model._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_len, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) + generation_config=generation_config, + input_ids_seq_length=input_ids_len, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) # prepare stopping criteria stopping_criteria = model._get_stopping_criteria( @@ -205,7 +210,7 @@ def stream_generate( logits_warper = model._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) scores = None - + while True: model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token @@ -244,4 +249,4 @@ def stream_generate( yield input_ids # stop when each sentence is finished, or if exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break \ No newline at end of file + break diff --git a/applications/Colossal-LLaMA-2/docs/example_13b.md b/applications/Colossal-LLaMA-2/docs/example_13b.md index 3854aaa6a..5305a999b 100644 --- a/applications/Colossal-LLaMA-2/docs/example_13b.md +++ b/applications/Colossal-LLaMA-2/docs/example_13b.md @@ -43,7 +43,7 @@ if __name__ == '__main__': model.to(device) tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-13b-base") - + question = "xxx" # Your question. prompt = (f"请对下面的问题,给出准确、可靠、安全的回答\n" f"{question}" @@ -111,7 +111,7 @@ The Colossal-AI team has conducted comprehensive evaluation for safety and value | 整车热管理主要包括什么内容 | 整车热管理主要包括:1、发动机冷却系统;2、散热系统;3、空调冷却系统;4、采暖系统;5、动力转向器冷却系统;6、润滑油冷却系统;7、变速箱冷却系统;8、燃油冷却系统;9、冷却系统排放液系统。 | | 简述下汽车空调-热泵循环系统的原理 | 汽车空调系统主要有三种形式,其中热泵循环系统是其中的一种。其原理是利用发动机排放的热量,通过压缩机将冷媒压缩成高温高压的状态,然后通过冷却器释放到空气中,吸收周围的热量,再通过膨胀阀释放到蒸发器中,蒸发器在压缩机的作用下将冷媒压缩成高温高压的状态,再次释放到空气中,从而实现制冷的效果。 | -#### Finance +#### Finance | Question |
                  Colossal-LLaMA-2-13b-base
                  | | :------: | :----------------------- | | 如何看待当前提前还贷热 | 提前还贷是一种常见的还贷方式,在市场中存在着许多的因素和影响,需要具体分析当前的经济形势和人们心理状况的影响。提前还贷可以帮助借款人早日摆脱债务负担,但也存在一些风险和不确定性。从经济形势来看,当前利率下探、还款压力减轻等因素可能会导致一部分借款人提前还贷。此外,一些购房者可能会受到市场不景气的影响,担心未来房产价值会下降,因此提前还贷以避免资产贬值。从借款人心理状况来看,提前还贷也可能会由于借款人对未来经济预期的不确定性而产生压力和不满。因此,提前还贷并不能一概而论,需要根据具体情况进行分析。如果提前还贷对借款人的经济情况和未来规划造成不利影响,建议谨慎考虑是否要提前还贷。| @@ -150,4 +150,4 @@ The Colossal-AI team has conducted comprehensive evaluation for safety and value ## Conclusion -The Colossal-AI team's advanced 13B model, compared to the 7B version, features a refined data structure categorizing information into informative, functional, and memory replay data. Informative data is intricately subdivided into major categories, each further segmented for precise control. Concurrently, data scale across domains is expanded. Tailored enhancements meet community demands for large model capabilities in natural language processing tasks, ensuring proficiency during pre-training and cost-effective fine-tuning. Addressing security and values concerns, multidimensional controls are implemented, securing the baseline model and aligning it with correct values. \ No newline at end of file +The Colossal-AI team's advanced 13B model, compared to the 7B version, features a refined data structure categorizing information into informative, functional, and memory replay data. Informative data is intricately subdivided into major categories, each further segmented for precise control. Concurrently, data scale across domains is expanded. Tailored enhancements meet community demands for large model capabilities in natural language processing tasks, ensuring proficiency during pre-training and cost-effective fine-tuning. Addressing security and values concerns, multidimensional controls are implemented, securing the baseline model and aligning it with correct values. diff --git a/applications/Colossal-LLaMA-2/docs/example_7b.md b/applications/Colossal-LLaMA-2/docs/example_7b.md index d889ab416..d833d2821 100644 --- a/applications/Colossal-LLaMA-2/docs/example_7b.md +++ b/applications/Colossal-LLaMA-2/docs/example_7b.md @@ -242,4 +242,4 @@ To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, ## Conclusion In general, the Colossal-LLaMA-2-7B-base model not only enhances its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge, encompassing various fields such as food, sports, technology, literature, games, and more. Regarding text generation tasks, the Colossal-LLaMA-2-7B-base model excels in writing performance; however, its ability to generate specific formats like code, emails, tables, etc., needs enhancement due to the scarcity of relevant training data during our training phase. When compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it in answering most English questions and some Chinese questions, as demonstrated in the examples above. -Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement in the future as part of our ongoing enhancements. \ No newline at end of file +Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. 
These capabilities are poised for further improvement in the future as part of our ongoing enhancements. diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example index 82948648c..cfaaa0ef5 100644 --- a/applications/Colossal-LLaMA-2/hostfile.example +++ b/applications/Colossal-LLaMA-2/hostfile.example @@ -1,2 +1,2 @@ hostname1 -hostname2 \ No newline at end of file +hostname2 diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py index 63ce91e50..8d301616d 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA-2/inference_example.py @@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs): model.to(device) try: - tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left') + tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left") except OSError: raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt index 34afaf7e5..5cdb8e7f3 100644 --- a/applications/Colossal-LLaMA-2/requirements.txt +++ b/applications/Colossal-LLaMA-2/requirements.txt @@ -12,4 +12,3 @@ flash-attn>=2.0.0,<=2.0.5 tqdm sentencepiece==0.1.99 protobuf<=3.20.0 - diff --git a/applications/Colossal-LLaMA-2/stream_chat_example.py b/applications/Colossal-LLaMA-2/stream_chat_example.py index 3e45c690f..4c0d1fe2a 100644 --- a/applications/Colossal-LLaMA-2/stream_chat_example.py +++ b/applications/Colossal-LLaMA-2/stream_chat_example.py @@ -1,11 +1,11 @@ -import os import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM from colossal_llama2.utils.stream_chat_patch import streaming_chat +from transformers import AutoModelForCausalLM, AutoTokenizer SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
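# --- Editor's aside: both the inference example above (padding_side="left") and the
# streaming-chat patch earlier in this commit keep padding on the left. A minimal
# sketch of why, using a placeholder tokenizer id ("gpt2"); this is an illustration,
# not part of the example script. Decoder-only generation appends new tokens after
# the last column of the batch, so shorter prompts must be padded on the left or the
# model would be asked to continue from pad tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
batch = tokenizer(
    ["a short prompt", "a noticeably longer prompt goes here"],
    return_tensors="pt",
    padding=True,
)
# Pad ids now sit at the start of the shorter row, keeping real tokens flush right.
print(batch["input_ids"])
# --- end of aside ---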
+ def main(args): model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) @@ -27,29 +27,34 @@ def main(args): print(f"\n{roles[2]}: ", end="") gen_len = 0 for response, history, past_key_values in streaming_chat( - model, tokenizer, input_query, history=history, roles=roles, - temperature = args.temperature, - top_p = args.top_p, - top_k = args.top_k, - do_sample = args.do_sample, - length_penalty = args.length_penalty, - max_new_tokens = args.max_new_tokens, + model, + tokenizer, + input_query, + history=history, + roles=roles, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + do_sample=args.do_sample, + length_penalty=args.length_penalty, + max_new_tokens=args.max_new_tokens, past_key_values=past_key_values, - return_past_key_values=True): - + return_past_key_values=True, + ): output = response[gen_len:] print(output, end="", flush=True) gen_len = len(response) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, default=None, help="path to chat version model") - parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer") - parser.add_argument('--temperature', type=float, default=0.8, help="set temperature") - parser.add_argument('--top_p', type=float, default=0.95, help="set top p value") - parser.add_argument('--top_k', type=int, default=50, help="set top k value") - parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not") - parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty") - parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens") + parser.add_argument("--model_path", type=str, default=None, help="path to chat version model") + parser.add_argument("--tokenizer_path", type=str, default=None, help="path to chat version tokenizer") + parser.add_argument("--temperature", type=float, default=0.8, help="set temperature") + parser.add_argument("--top_p", type=float, default=0.95, help="set top p value") + parser.add_argument("--top_k", type=int, default=50, help="set top k value") + parser.add_argument("--do_sample", type=bool, default=True, help="whether turn on do_sample or not") + parser.add_argument("--length_penalty", type=float, default=1.2, help="set length penalty") + parser.add_argument("--max_new_tokens", type=int, default=512, help="set max new tokens") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt index 8a9ecc2ea..8acdd82b7 100644 --- a/applications/Colossal-LLaMA-2/version.txt +++ b/applications/Colossal-LLaMA-2/version.txt @@ -1 +1 @@ -0.0.1 \ No newline at end of file +0.0.1 diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py index aae034683..b9287eb1a 100755 --- a/applications/ColossalChat/examples/training_scripts/train_dpo.py +++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py @@ -20,13 +20,13 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator +from colossalai.logging import get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from 
colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device -from colossalai.logging import get_dist_logger logger = get_dist_logger() + def train(args): # check lora compatibility if "gemini" in args.plugin and args.lora_rank > 0: diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index a340f3bfd..13bbb12b6 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -3,7 +3,6 @@ import copy import os from typing import Dict, List -import torch import torch.distributed as dist from colossal_eval import dataset, models, utils diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md index be50a8f9f25111edabd7c13660bb976f9cc1f252..ba864d1dff8b2c52b3a5c45261f37586ecdb5bc1 100644 GIT binary patch delta 9 QcmX?QblPZxmn0(>02EsT7ytkO delta 7 OcmX?YbjoOhmm~lVHv None: - ''' + """ + + def __init__(self, files: str, sql_path: str = "sqlite:///mydatabase.db", verbose=False, **kwargs) -> None: + """ Args: files: list of files (list[file path, name]) sql_path: how to serve the sql database - **kwargs: keyword type arguments, useful for certain document types - ''' + **kwargs: keyword type arguments, useful for certain document types + """ self.data = {} self.verbose = verbose self.sql_path = sql_path self.kwargs = kwargs self.sql_engine = create_engine(self.sql_path) drop_table(self.sql_engine) - + self.sql_engine = create_engine(self.sql_path) for item in files: path = item[0] @@ -42,68 +43,68 @@ class TableLoader: raise FileNotFoundError(f"{path} doesn't exists") if not any([path.endswith(i) for i in SUPPORTED_DATA_FORMAT]): raise TypeError(f"{path} not supported. Supported type {SUPPORTED_DATA_FORMAT}") - + logger.info("loading data", verbose=self.verbose) self.load_data(path) logger.info("data loaded", verbose=self.verbose) self.to_sql(path, dataset_name) def load_data(self, path): - ''' + """ Load data and serve the data as sql database. Data must be in pandas format - ''' + """ files = [] # Handle glob expression try: files = glob.glob(path) except Exception as e: logger.error(e) - if len(files)==0: + if len(files) == 0: raise ValueError("Unsupported file/directory format. 
For directories, please use glob expression") - elif len(files)==1: + elif len(files) == 1: path = files[0] else: for file in files: self.load_data(file) - if path.endswith('.csv'): + if path.endswith(".csv"): # Load csv self.data[path] = pd.read_csv(path) - elif path.endswith('.xlsx') or path.endswith('.xls'): + elif path.endswith(".xlsx") or path.endswith(".xls"): # Load excel self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed - elif path.endswith('.json'): + elif path.endswith(".json"): # Load json self.data[path] = pd.read_json(path) - elif path.endswith('.html'): + elif path.endswith(".html"): # Load html html_tables = pd.read_html(path) # Choose the desired table from the list of DataFrame objects self.data[path] = html_tables[0] # You may need to adjust this index - elif path.endswith('.h5') or path.endswith('.hdf5'): + elif path.endswith(".h5") or path.endswith(".hdf5"): # Load h5 - self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed - elif path.endswith('.parquet'): + self.data[path] = pd.read_hdf(path, key=self.kwargs.get("key", "data")) # You can adjust the key as needed + elif path.endswith(".parquet"): # Load parquet - self.data[path] = pd.read_parquet(path, engine='fastparquet') - elif path.endswith('.feather'): + self.data[path] = pd.read_parquet(path, engine="fastparquet") + elif path.endswith(".feather"): # Load feather self.data[path] = pd.read_feather(path) - elif path.endswith('.dta'): + elif path.endswith(".dta"): # Load dta self.data[path] = pd.read_stata(path) else: raise ValueError("Unsupported file format") - + def to_sql(self, path, table_name): - ''' + """ Serve the data as sql database. - ''' - self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False) + """ + self.data[path].to_sql(table_name, con=self.sql_engine, if_exists="replace", index=False) logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose) return self.sql_path - + def get_sql_path(self): return self.sql_path @@ -113,7 +114,3 @@ class TableLoader: self.sql_engine.dispose() del self.data del self.sql_engine - - - - diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py index 483f65e20..362977869 100644 --- a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py +++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py @@ -21,7 +21,7 @@ print(resp) # super-heavyweight awesome-natured yawning Australian creature! 
""" import json -from typing import Any, List, Mapping, Optional +from typing import Any, Mapping import requests from langchain.llms.base import LLM @@ -31,31 +31,31 @@ from langchain.utils import get_from_dict_or_env class ColossalCloudLLM(LLM): """ A custom LLM class that integrates LLMs running on the ColossalCloud Platform - + """ - n: int - gen_config: dict = None + + n: int + gen_config: dict = None auth_config: dict = None - valid_gen_para: list = ['max_new_tokens', 'top_k', - 'top_p', 'temperature', 'repetition_penalty'] + valid_gen_para: list = ["max_new_tokens", "top_k", "top_p", "temperature", "repetition_penalty"] def __init__(self, gen_config=None, **kwargs): """ Args: gen_config: config for generation, max_new_tokens: 50 by default - top_k: (1, vocab_size) + top_k: (1, vocab_size) top_p: (0, 1) if not None - temperature: (0, inf) if not None + temperature: (0, inf) if not None repetition_penalty: (1, inf) if not None """ super(ColossalCloudLLM, self).__init__(**kwargs) - if gen_config is None: - self.gen_config = {"max_new_tokens": 50} - else: + if gen_config is None: + self.gen_config = {"max_new_tokens": 50} + else: assert "max_new_tokens" in gen_config, "max_new_tokens is a compulsory key in the gen config" self.gen_config = gen_config - + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" @@ -63,17 +63,17 @@ class ColossalCloudLLM(LLM): @property def _llm_type(self) -> str: - return 'ColossalCloudLLM' - + return "ColossalCloudLLM" + def set_auth_config(self, **kwargs): url = get_from_dict_or_env(kwargs, "url", "URL") host = get_from_dict_or_env(kwargs, "host", "HOST") - + auth_config = {} - auth_config['endpoint'] = url - auth_config['Host'] = host + auth_config["endpoint"] = url + auth_config["Host"] = host self.auth_config = auth_config - + def _call(self, prompt: str, stop=None, **kwargs: Any) -> str: """ Args: @@ -81,15 +81,17 @@ class ColossalCloudLLM(LLM): stop: A list of strings to stop generation when encountered Returns: - The string generated by the model + The string generated by the model """ # Update the generation arguments for key, value in kwargs.items(): if key not in self.valid_gen_para: - raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}") + raise KeyError( + f"Invalid generation parameter: '{key}'. 
Valid keys are: {', '.join(self.valid_gen_para)}" + ) if key in self.gen_config: self.gen_config[key] = value - + resp_text = self.text_completion(prompt, self.gen_config, self.auth_config) # TODO: This may cause excessive tokens count if stop is not None: @@ -97,29 +99,19 @@ class ColossalCloudLLM(LLM): if stopping_words in resp_text: resp_text = resp_text.split(stopping_words)[0] return resp_text - def text_completion(self, prompt, gen_config, auth_config): # Required Parameters - endpoint = auth_config.pop('endpoint') - max_new_tokens = gen_config.pop('max_new_tokens') + endpoint = auth_config.pop("endpoint") + max_new_tokens = gen_config.pop("max_new_tokens") # Optional Parameters - optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional + optional_params = ["top_k", "top_p", "temperature", "repetition_penalty"] # Self.optional gen_config = {key: gen_config[key] for key in optional_params if key in gen_config} # Define the data payload - data = { - "max_new_tokens": max_new_tokens, - "history": [ - {"instruction": prompt, "response": ""} - ], - **gen_config - } - headers = { - "Content-Type": "application/json", - **auth_config # 'Host', - } + data = {"max_new_tokens": max_new_tokens, "history": [{"instruction": prompt, "response": ""}], **gen_config} + headers = {"Content-Type": "application/json", **auth_config} # 'Host', # Make the POST request response = requests.post(endpoint, headers=headers, data=json.dumps(data)) - response.raise_for_status() # raise error if return code is not 200(success) + response.raise_for_status() # raise error if return code is not 200(success) # Check the response return response.text diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py index 4cc82c1a3..30a456c3d 100644 --- a/applications/ColossalQA/colossalqa/local/llm.py +++ b/applications/ColossalQA/colossalqa/local/llm.py @@ -193,4 +193,3 @@ class VllmLLM(LLM): def _identifying_params(self) -> Mapping[str, int]: """Get the identifying parameters.""" return {"n": self.n} - diff --git a/applications/ColossalQA/colossalqa/prompt/prompt.py b/applications/ColossalQA/colossalqa/prompt/prompt.py index d62249ba9..a8aeded3a 100644 --- a/applications/ColossalQA/colossalqa/prompt/prompt.py +++ b/applications/ColossalQA/colossalqa/prompt/prompt.py @@ -4,7 +4,6 @@ All custom prompt templates are defined here. 
from langchain.prompts.prompt import PromptTemplate - # Below are Chinese retrieval qa prompts _CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。 diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py index 870a73ca5..6a0c69859 100644 --- a/applications/ColossalQA/colossalqa/retriever.py +++ b/applications/ColossalQA/colossalqa/retriever.py @@ -99,13 +99,7 @@ class CustomRetriever(BaseRetriever): def clear_documents(self): """Clear all document vectors from database""" for source in self.vector_stores: - index( - [], - self.record_managers[source], - self.vector_stores[source], - cleanup="full", - source_id_key="source" - ) + index([], self.record_managers[source], self.vector_stores[source], cleanup="full", source_id_key="source") self.vector_stores = {} self.sql_index_database = {} self.record_managers = {} diff --git a/applications/ColossalQA/data/data_sample/companies.txt b/applications/ColossalQA/data/data_sample/companies.txt index 05c6148f1..4b297e6ff 100644 --- a/applications/ColossalQA/data/data_sample/companies.txt +++ b/applications/ColossalQA/data/data_sample/companies.txt @@ -1,6 +1,6 @@ Overview The Straits Times is the English flagship daily of SPH Media, one of the leading media companies in Asia. Launched on July 15, 1845, its comprehensive coverage of news from home and around the world makes The Straits Times the most-read newspaper in Singapore. Quality news, in-depth analyses, impactful commentaries and breaking stories are packaged to give readers riveting accounts of events in Singapore, the region, and beyond. The most read newspaper in Singapore, both in terms of print and digital, it reaches 1.33 million people every day. The Straits Times'​ key strength is in its world class coverage of news outside Singapore. With 20 bureaus in major cities around the world, The Straits Times correspondents bring world news to readers on a Singapore platter, helping readers to appreciate world events from a Singaporean perspective. Website http://www.straitstimes.com Phone 63196319Phone number is 63196319 Industry Newspaper Publishing Company size 1,001-5,000 employees 183 on LinkedIn Includes members with current employer listed as The Straits Times, including part-time roles. Headquarters Singapore, Singapore Founded 1845 Specialties News and Digital media -About With over 500 properties worldwide, Marriott Hotels has reimagined hospitality to exceed the expectations of business, group, and leisure travelers. +About With over 500 properties worldwide, Marriott Hotels has reimagined hospitality to exceed the expectations of business, group, and leisure travelers. Marriott Hotels, Marriott’s flagship brand of quality-tier, full-service hotels and resorts, provides consistent, dependable and genuinely caring experiences to guests on their terms. Marriott is a brilliant host to guests who effortlessly blend life and work, and who are inspired by how modern travel enhances them both. Our hotels offer warm, professional service; sophisticated yet functional guest room design; lobby spaces that facilitate working, dining and socializing; restaurants and bars serving international cuisine prepared simply and from the freshest ingredients; meeting and event spaces and services that are gold standard; and expansive, 24-hour fitness facilities. -Overview AERCO International, Inc. 
is a recognized leader in delivering cost-effective, condensing commercial boilers, high-efficiency water heaters across a variety of markets including education, lodging, government, office buildings, healthcare, industrial and multifamily housing. AERCO's system design approach provides customer-specific solutions that deliver superior building performance at a lower operating cost while assuring uptime reliability. When AERCO was founded in 1949, it introduced a revolutionary design for an indirect-fired water heater that heated water on demand, and without storage, at a controlled temperature. This innovation became today's standard for water heaters, maximizing the recovery of latent heat energy and significantly increasing operating efficiency. AERCO continued to innovate and in 1988, introduced the first condensing and fully modulating boiler and water heater to the commercial market. The modulating capability of these products, still unsurpassed more than 25 years later, matches the equipment's output to real-time heating demand, ensuring the units draw no more fuel to operate than is absolutely necessary. This not only saves precious energy, but also ensures money doesn't needlessly disappear "up the stack."​ AERCO differentiates itself through a solution-based model, leveraging decades of engineering experience and industry application expertise to understand each customer’s unique needs. By partnering directly with customers and end-users to understand their project-specific requirements, AERCO provides tailored application solutions that are comprised of original product technologies including high efficiency condensing products, compact footprints, high turndown ratios, unique fuel delivery, leading control systems and proprietary design elements that combine to deliver up to 99% efficiency. Website http://www.aerco.com Phone 845-580-8000Phone number is 845-580-8000 Industry Industrial Machinery Manufacturing Company size 51-200 employees 119 on LinkedIn Includes members with current employer listed as AERCO International, Inc., including part-time roles. Headquarters Blauvelt, NY Founded 1949 Specialties Leading manufacturer of condensing boilers, water heating and energy recovery products and The originator of semi-instantaneous water heating +Overview AERCO International, Inc. is a recognized leader in delivering cost-effective, condensing commercial boilers, high-efficiency water heaters across a variety of markets including education, lodging, government, office buildings, healthcare, industrial and multifamily housing. AERCO's system design approach provides customer-specific solutions that deliver superior building performance at a lower operating cost while assuring uptime reliability. When AERCO was founded in 1949, it introduced a revolutionary design for an indirect-fired water heater that heated water on demand, and without storage, at a controlled temperature. This innovation became today's standard for water heaters, maximizing the recovery of latent heat energy and significantly increasing operating efficiency. AERCO continued to innovate and in 1988, introduced the first condensing and fully modulating boiler and water heater to the commercial market. The modulating capability of these products, still unsurpassed more than 25 years later, matches the equipment's output to real-time heating demand, ensuring the units draw no more fuel to operate than is absolutely necessary. 
This not only saves precious energy, but also ensures money doesn't needlessly disappear "up the stack."​ AERCO differentiates itself through a solution-based model, leveraging decades of engineering experience and industry application expertise to understand each customer’s unique needs. By partnering directly with customers and end-users to understand their project-specific requirements, AERCO provides tailored application solutions that are comprised of original product technologies including high efficiency condensing products, compact footprints, high turndown ratios, unique fuel delivery, leading control systems and proprietary design elements that combine to deliver up to 99% efficiency. Website http://www.aerco.com Phone 845-580-8000Phone number is 845-580-8000 Industry Industrial Machinery Manufacturing Company size 51-200 employees 119 on LinkedIn Includes members with current employer listed as AERCO International, Inc., including part-time roles. Headquarters Blauvelt, NY Founded 1949 Specialties Leading manufacturer of condensing boilers, water heating and energy recovery products and The originator of semi-instantaneous water heating Prince PLC: Overview We are a global leader of quality water solutions for residential, industrial, municipal, and commercial settings. Our family of brands offers one of the most varied product lines in the world, with world-class, water-related solutions focused on: • Plumbing & Flow Control • Water Quality & Conditioning • Water Reuse & Drainage • HVAC • Municipal Waterworks Strategic Goals Watts Water is traded on the New York Stock Exchange under the symbol “WTS.” As a public company, growing shareholder value is critical. To that end, we focus on a five-part Global Strategy: Growth, Commercial Excellence, Operational Excellence, “One Watts Water,” and a Talent & Performance Culture. Follow us on all social media platforms @WattsWater Website http://www.watts.com/ Industry Wholesale Building Materials Company size 5,001-10,000 employees 2,248 on LinkedIn Includes members with current employer listed as Watts Water Technologies, including part-time roles. Headquarters North Andover, MA Specialties Plumbing, HVAC, Water Quality, Gas, Conditioning, Waterworks, and Drainage -About Courtyard Hotels is Marriott International’s largest hotel brand, with more than 1,100 hotels in over 50 countries worldwide. So, no matter where passion takes you, you’ll find us there to help you follow it. Proud members of Marriott Bonvoy. \ No newline at end of file +About Courtyard Hotels is Marriott International’s largest hotel brand, with more than 1,100 hotels in over 50 countries worldwide. So, no matter where passion takes you, you’ll find us there to help you follow it. Proud members of Marriott Bonvoy. 
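The TableLoader hunks above are formatting-only (quote style, spacing, Black-style reflow); the loader's behavior is unchanged: it reads CSV/Excel/JSON/HTML/HDF5/Parquet/Feather/Stata files into pandas and serves each one as a SQLite table via SQLAlchemy. Below is a minimal usage sketch, not part of the patch — the import path is assumed to be colossalqa.data_loader.table_dataloader (the diff header for that file is garbled above), and the CSV path and table name are illustrative only:

# Hedged usage sketch for TableLoader; the import path and file names are assumptions.
from colossalqa.data_loader.table_dataloader import TableLoader

# 'files' takes [path, table_name] pairs; sql_path default matches the constructor shown above.
loader = TableLoader(
    files=[["data/data_sample/csv_organization_100.csv", "organizations"]],
    sql_path="sqlite:///mydatabase.db",
    verbose=True,
)
print(loader.get_sql_path())  # "sqlite:///mydatabase.db" — ready for SQL-based retrieval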
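Likewise, the colossalcloud_llm.py hunks only reformat the ColossalCloudLLM wrapper, which exposes a ColossalCloud inference endpoint through LangChain's LLM interface: set_auth_config() records the endpoint URL and Host header (or falls back to the URL/HOST environment variables), and _call() POSTs the prompt together with the generation config. A hedged sketch of how it might be driven — the endpoint and host strings are placeholders, not real services:

# Hedged usage sketch; the endpoint/host values are placeholders and not part of the patch.
from colossalqa.local.colossalcloud_llm import ColossalCloudLLM

# 'n' is a declared field on the class; gen_config must contain "max_new_tokens".
llm = ColossalCloudLLM(n=1, gen_config={"max_new_tokens": 50, "temperature": 0.7})
llm.set_auth_config(url="https://your-colossalcloud-endpoint/generate", host="your-colossalcloud-host")
resp = llm("What do you call a three-ton kangaroo?")  # LangChain's LLM interface invokes _call() under the hood
print(resp)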
diff --git a/applications/ColossalQA/data/data_sample/companies_zh.txt b/applications/ColossalQA/data/data_sample/companies_zh.txt index a67a93590..511873bde 100644 --- a/applications/ColossalQA/data/data_sample/companies_zh.txt +++ b/applications/ColossalQA/data/data_sample/companies_zh.txt @@ -3,4 +3,4 @@ 万豪酒店(Marriott Hotels)是万豪旗下优质、全方位服务酒店和度假村的旗舰品牌,为客人提供始终如一、可靠和真诚关怀的体验。万豪是一个出色的主人,客人可以轻松地将生活和工作融合在一起,并受到现代旅行如何增强两者的启发。我们的酒店提供热情、专业的服务;精致而实用的客房设计;大堂空间,方便工作、餐饮和社交;餐厅和酒吧提供简单的国际美食和最新鲜的食材;会议及活动场地及服务均属黄金标准;还有宽敞的24小时健身设施。 AERCO International, Inc.是公认的领导者,为教育、住宿、政府、办公楼、医疗保健、工业和多户住宅等各种市场提供具有成本效益的冷凝商用锅炉和高效热水器。AERCO的系统设计方法为客户提供特定的解决方案,以较低的运营成本提供卓越的建筑性能,同时确保正常运行时间的可靠性。AERCO成立于1949年,它推出了一种革命性的设计,用于间接燃烧热水器,在控制温度下按需加热水,而无需储存。这一创新成为当今热水器的标准,最大限度地回收潜热能量,显著提高运行效率。AERCO不断创新,并于1988年向商业市场推出了第一台冷凝和全调制锅炉和热水器。这些产品的调制能力,在超过25年后仍然无与伦比,使设备的输出与实时加热需求相匹配,确保机组不会消耗更多的燃料来运行,除非绝对必要。这不仅节省了宝贵的能源,还确保了钱不会不必要地消失在“堆栈”上。AERCO通过基于解决方案的模式脱颖而出,利用数十年的工程经验和行业应用专业知识来了解每个客户的独特需求。通过与客户和最终用户直接合作,了解他们的项目具体要求,AERCO提供量身定制的应用解决方案,这些解决方案由原创产品技术组成,包括高效冷凝产品,紧凑的足迹,高降压比,独特的燃料输送,领先的控制系统和专有设计元素,结合起来可提供高达99%的效率。网址http://www.aerco.com电话845-580- 8000电话号码845-580-8000工业工业机械制造公司规模51-200名员工LinkedIn上包括当前雇主AERCO International, Inc的成员,包括兼职职位。总部成立于1949年,纽约州布劳维尔特,专长:冷凝锅炉,水加热和能源回收产品的领先制造商,半瞬时水加热的鼻祖 Prince PLC:概述Prince PLC是为住宅、工业、市政和商业环境提供优质水解决方案的全球领导者。我们的品牌家族提供世界上最多样化的产品线之一,拥有世界级的水相关解决方案,专注于:•管道和流量控制•水质和调理•水再利用和排水•hvac•市政水务战略目标瓦茨水务在纽约证券交易所上市,代码为“WTS”。作为一家上市公司,股东价值的增长至关重要。为此,我们将重点放在五部分全球战略上:增长、卓越商业、卓越运营、“一瓦茨水”以及人才与绩效文化。在所有社交媒体平台关注我们@WattsWater网站http://www.watts.com/行业批发建材公司规模5,001-10,000名员工领英2,248名包括目前雇主为WattsWater Technologies的成员,包括兼职职位。总部北安多弗,MA专业管道,暖通空调,水质,气体,空调,自来水厂和排水 -万怡酒店是万豪国际最大的酒店品牌,在全球50多个国家拥有1100多家酒店。所以,无论你的激情带你去哪里,你都会发现我们会帮助你追随它。万豪酒店的骄傲会员。 \ No newline at end of file +万怡酒店是万豪国际最大的酒店品牌,在全球50多个国家拥有1100多家酒店。所以,无论你的激情带你去哪里,你都会发现我们会帮助你追随它。万豪酒店的骄傲会员。 diff --git a/applications/ColossalQA/data/data_sample/csv_organization_100.csv b/applications/ColossalQA/data/data_sample/csv_organization_100.csv index dbe97d5fd..5d88bfb94 100644 --- a/applications/ColossalQA/data/data_sample/csv_organization_100.csv +++ b/applications/ColossalQA/data/data_sample/csv_organization_100.csv @@ -98,4 +98,4 @@ Index,Organization Id,Company Name,Website,Country,Description,Founded,Industry, 97,BA6Cd9Dae2Efd62,Good Ltd,http://duffy.com/,Anguilla,Reverse-engineered composite moratorium,1971,Consumer Services,4292 98,E7df80C60Abd7f9,Clements-Espinoza,http://www.flowers.net/,Falkland Islands (Malvinas),Progressive modular hub,1991,Broadcast Media,236 99,AFc285dbE2fEd24,Mendez Inc,https://www.burke.net/,Kyrgyz Republic,User-friendly exuding migration,1993,Education Management,339 -100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785 \ No newline at end of file +100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785 diff --git a/applications/ColossalQA/data/tests/64KB.json b/applications/ColossalQA/data/tests/64KB.json index 99278dc5c..41e1aeedf 100644 --- a/applications/ColossalQA/data/tests/64KB.json +++ b/applications/ColossalQA/data/tests/64KB.json @@ -4,4 +4,4 @@ {"content":"Aliquam sollicitudin ante ligula, eget malesuada nibh efficitur et. Pellentesque massa sem, scelerisque sit amet odio id, cursus tempor urna. Etiam congue dignissim volutpat. 
Vestibulum pharetra libero et velit gravida euismod."} ], "name":"player" -} \ No newline at end of file +} diff --git a/applications/ColossalQA/data/tests/companies.csv b/applications/ColossalQA/data/tests/companies.csv index 93dcac9f3..a111992d7 100644 --- a/applications/ColossalQA/data/tests/companies.csv +++ b/applications/ColossalQA/data/tests/companies.csv @@ -1,101 +1,101 @@ -Index,Organization Id,Name,Website,Country,Description,Founded,Industry,Number of employees -1,FAB0d41d5b5d22c,Ferrell LLC,https://price.net/,Papua New Guinea,Horizontal empowering knowledgebase,1990,Plastics,3498 -2,6A7EdDEA9FaDC52,"Mckinney, Riley and Day",http://www.hall-buchanan.info/,Finland,User-centric system-worthy leverage,2015,Glass / Ceramics / Concrete,4952 -3,0bFED1ADAE4bcC1,Hester Ltd,http://sullivan-reed.com/,China,Switchable scalable moratorium,1971,Public Safety,5287 -4,2bFC1Be8a4ce42f,Holder-Sellers,https://becker.com/,Turkmenistan,De-engineered systemic artificial intelligence,2004,Automotive,921 -5,9eE8A6a4Eb96C24,Mayer Group,http://www.brewer.com/,Mauritius,Synchronized needs-based challenge,1991,Transportation,7870 -6,cC757116fe1C085,Henry-Thompson,http://morse.net/,Bahamas,Face-to-face well-modulated customer loyalty,1992,Primary / Secondary Education,4914 -7,219233e8aFF1BC3,Hansen-Everett,https://www.kidd.org/,Pakistan,Seamless disintermediate collaboration,2018,Publishing Industry,7832 -8,ccc93DCF81a31CD,Mcintosh-Mora,https://www.brooks.com/,Heard Island and McDonald Islands,Centralized attitude-oriented capability,1970,Import / Export,4389 -9,0B4F93aA06ED03e,Carr Inc,http://ross.com/,Kuwait,Distributed impactful customer loyalty,1996,Plastics,8167 -10,738b5aDe6B1C6A5,Gaines Inc,http://sandoval-hooper.com/,Uzbekistan,Multi-lateral scalable protocol,1997,Outsourcing / Offshoring,9698 -11,AE61b8Ffebbc476,Kidd Group,http://www.lyons.com/,Bouvet Island (Bouvetoya),Proactive foreground paradigm,2001,Primary / Secondary Education,7473 -12,eb3B7D06cCdD609,Crane-Clarke,https://www.sandoval.com/,Denmark,Front-line clear-thinking encryption,2014,Food / Beverages,9011 -13,8D0c29189C9798B,"Keller, Campos and Black",https://www.garner.info/,Liberia,Ameliorated directional emulation,2020,Museums / Institutions,2862 -14,D2c91cc03CA394c,Glover-Pope,http://www.silva.biz/,United Arab Emirates,Persevering contextually-based approach,2013,Medical Practice,9079 -15,C8AC1eaf9C036F4,Pacheco-Spears,https://aguilar.com/,Sweden,Secured logistical synergy,1984,Maritime,769 -16,b5D10A14f7a8AfE,Hodge-Ayers,http://www.archer-elliott.com/,Honduras,Future-proofed radical implementation,1990,Facilities Services,8508 -17,68139b5C4De03B4,"Bowers, Guerra and Krause",http://www.carrillo-nicholson.com/,Uganda,De-engineered transitional strategy,1972,Primary / Secondary Education,6986 -18,5c2EffEfdba2BdF,Mckenzie-Melton,http://montoya-thompson.com/,Hong Kong,Reverse-engineered heuristic alliance,1998,Investment Management / Hedge Fund / Private Equity,4589 -19,ba179F19F7925f5,Branch-Mann,http://www.lozano.com/,Botswana,Adaptive intangible frame,1999,Architecture / Planning,7961 -20,c1Ce9B350BAc66b,Weiss and Sons,https://barrett.com/,Korea,Sharable optimal functionalities,2011,Plastics,5984 -21,8de40AC4e6EaCa4,"Velez, Payne and Coffey",http://burton.com/,Luxembourg,Mandatory coherent synergy,1986,Wholesale,5010 -22,Aad86a4F0385F2d,Harrell LLC,http://www.frey-rosario.com/,Guadeloupe,Reverse-engineered mission-critical moratorium,2018,Construction,2185 -23,22aC3FFd64fD703,"Eaton, Reynolds and 
Vargas",http://www.freeman.biz/,Monaco,Self-enabling multi-tasking process improvement,2014,Luxury Goods / Jewelry,8987 -24,5Ec4C272bCf085c,Robbins-Cummings,http://donaldson-wilkins.com/,Belgium,Organic non-volatile hierarchy,1991,Pharmaceuticals,5038 -25,5fDBeA8BB91a000,Jenkins Inc,http://www.kirk.biz/,South Africa,Front-line systematic help-desk,2002,Insurance,1215 -26,dFfD6a6F9AC2d9C,"Greene, Benjamin and Novak",http://www.kent.net/,Romania,Centralized leadingedge moratorium,2012,Museums / Institutions,4941 -27,4B217cC5a0674C5,"Dickson, Richmond and Clay",http://everett.com/,Czech Republic,Team-oriented tangible complexity,1980,Real Estate / Mortgage,3122 -28,88b1f1cDcf59a37,Prince-David,http://thompson.com/,Christmas Island,Virtual holistic methodology,1970,Banking / Mortgage,1046 -29,f9F7bBCAEeC360F,Ayala LLC,http://www.zhang.com/,Philippines,Open-source zero administration hierarchy,2021,Legal Services,7664 -30,7Cb3AeFcE4Ba31e,Rivas Group,https://hebert.org/,Australia,Open-architected well-modulated capacity,1998,Logistics / Procurement,4155 -31,ccBcC32adcbc530,"Sloan, Mays and Whitehead",http://lawson.com/,Chad,Face-to-face high-level conglomeration,1997,Civil Engineering,365 -32,f5afd686b3d05F5,"Durham, Allen and Barnes",http://chan-stafford.org/,Zimbabwe,Synergistic web-enabled framework,1993,Mechanical or Industrial Engineering,6135 -33,38C6cfC5074Fa5e,Fritz-Franklin,http://www.lambert.com/,Nepal,Automated 4thgeneration website,1972,Hospitality,4516 -34,5Cd7efccCcba38f,Burch-Ewing,http://cline.net/,Taiwan,User-centric 4thgeneration system engine,1981,Venture Capital / VC,7443 -35,9E6Acb51e3F9d6F,"Glass, Barrera and Turner",https://dunlap.com/,Kyrgyz Republic,Multi-channeled 3rdgeneration open system,2020,Utilities,2610 -36,4D4d7E18321eaeC,Pineda-Cox,http://aguilar.org/,Bolivia,Fundamental asynchronous capability,2010,Human Resources / HR,1312 -37,485f5d06B938F2b,"Baker, Mccann and Macdonald",http://www.anderson-barker.com/,Kenya,Cross-group user-facing focus group,2013,Legislative Office,1638 -38,19E3a5Bf6dBDc4F,Cuevas-Moss,https://dodson-castaneda.net/,Guatemala,Extended human-resource intranet,1994,Music,9995 -39,6883A965c7b68F7,Hahn PLC,http://newman.com/,Belarus,Organic logistical leverage,2012,Electrical / Electronic Manufacturing,3715 -40,AC5B7AA74Aa4A2E,"Valentine, Ferguson and Kramer",http://stuart.net/,Jersey,Centralized secondary time-frame,1997,Non - Profit / Volunteering,3585 -41,decab0D5027CA6a,Arroyo Inc,https://www.turner.com/,Grenada,Managed demand-driven website,2006,Writing / Editing,9067 -42,dF084FbBb613eea,Walls LLC,http://www.reese-vasquez.biz/,Cape Verde,Self-enabling fresh-thinking installation,1989,Investment Management / Hedge Fund / Private Equity,1678 -43,A2D89Ab9bCcAd4e,"Mitchell, Warren and Schneider",https://fox.biz/,Trinidad and Tobago,Enhanced intangible time-frame,2021,Capital Markets / Hedge Fund / Private Equity,3816 -44,77aDc905434a49f,Prince PLC,https://www.watts.com/,Sweden,Profit-focused coherent installation,2016,Individual / Family Services,7645 -45,235fdEFE2cfDa5F,Brock-Blackwell,http://www.small.com/,Benin,Secured foreground emulation,1986,Online Publishing,7034 -46,1eD64cFe986BBbE,Walton-Barnett,https://ashley-schaefer.com/,Western Sahara,Right-sized clear-thinking flexibility,2001,Luxury Goods / Jewelry,1746 -47,CbBbFcdd0eaE2cF,Bartlett-Arroyo,https://cruz.com/,Northern Mariana Islands,Realigned didactic function,1976,Civic / Social Organization,3987 -48,49aECbDaE6aBD53,"Wallace, Madden and 
Morris",http://www.blevins-fernandez.biz/,Germany,Persistent real-time customer loyalty,2016,Pharmaceuticals,9443 -49,7b3fe6e7E72bFa4,Berg-Sparks,https://cisneros-love.com/,Canada,Stand-alone static implementation,1974,Arts / Crafts,2073 -50,c6DedA82A8aef7E,Gonzales Ltd,http://bird.com/,Tonga,Managed human-resource policy,1988,Consumer Goods,9069 -51,7D9FBF85cdC3871,Lawson and Sons,https://www.wong.com/,French Southern Territories,Compatible analyzing intranet,2021,Arts / Crafts,3527 -52,7dd18Fb7cB07b65,"Mcguire, Mcconnell and Olsen",https://melton-briggs.com/,Korea,Profound client-server frame,1988,Printing,8445 -53,EF5B55FadccB8Fe,Charles-Phillips,https://bowman.com/,Cote d'Ivoire,Monitored client-server implementation,2012,Mental Health Care,3450 -54,f8D4B99e11fAF5D,Odom Ltd,https://www.humphrey-hess.com/,Cote d'Ivoire,Advanced static process improvement,2012,Management Consulting,1825 -55,e24D21BFd3bF1E5,Richard PLC,https://holden-coleman.net/,Mayotte,Object-based optimizing model,1971,Broadcast Media,4942 -56,B9BdfEB6D3Ca44E,Sampson Ltd,https://blevins.com/,Cayman Islands,Intuitive local adapter,2005,Farming,1418 -57,2a74D6f3D3B268e,"Cherry, Le and Callahan",https://waller-delacruz.biz/,Nigeria,Universal human-resource collaboration,2017,Entertainment / Movie Production,7202 -58,Bf3F3f62c8aBC33,Cherry PLC,https://www.avila.info/,Marshall Islands,Persistent tertiary website,1980,Plastics,8245 -59,aeBe26B80a7a23c,Melton-Nichols,https://kennedy.com/,Palau,User-friendly clear-thinking productivity,2021,Legislative Office,8741 -60,aAeb29ad43886C6,Potter-Walsh,http://thomas-french.org/,Turkey,Optional non-volatile open system,2008,Human Resources / HR,6923 -61,bD1bc6bB6d1FeD3,Freeman-Chen,https://mathis.com/,Timor-Leste,Phased next generation adapter,1973,International Trade / Development,346 -62,EB9f456e8b7022a,Soto Group,https://norris.info/,Vietnam,Enterprise-wide executive installation,1988,Business Supplies / Equipment,9097 -63,Dfef38C51D8DAe3,"Poole, Cruz and Whitney",https://reed.info/,Reunion,Balanced analyzing groupware,1978,Marketing / Advertising / Sales,2992 -64,055ffEfB2Dd95B0,Riley Ltd,http://wiley.com/,Brazil,Optional exuding superstructure,1986,Textiles,9315 -65,cBfe4dbAE1699da,"Erickson, Andrews and Bailey",https://www.hobbs-grant.com/,Eritrea,Vision-oriented secondary project,2014,Consumer Electronics,7829 -66,fdFbecbadcdCdf1,"Wilkinson, Charles and Arroyo",http://hunter-mcfarland.com/,United States Virgin Islands,Assimilated 24/7 archive,1996,Building Materials,602 -67,5DCb8A5a5ca03c0,Floyd Ltd,http://www.whitney.com/,Falkland Islands (Malvinas),Function-based fault-tolerant concept,2017,Public Relations / PR,2911 -68,ce57DCbcFD6d618,Newman-Galloway,https://www.scott.com/,Luxembourg,Enhanced foreground collaboration,1987,Information Technology / IT,3934 -69,5aaD187dc929371,Frazier-Butler,https://www.daugherty-farley.info/,Northern Mariana Islands,Persistent interactive circuit,1972,Outsourcing / Offshoring,5130 -70,902D7Ac8b6d476b,Newton Inc,https://www.richmond-manning.info/,Netherlands Antilles,Fundamental stable info-mediaries,1976,Military Industry,563 -71,32BB9Ff4d939788,Duffy-Levy,https://www.potter.com/,Guernsey,Diverse exuding installation,1982,Wireless,6146 -72,adcB0afbE58bAe3,Wagner LLC,https://decker-esparza.com/,Uruguay,Reactive attitude-oriented toolset,1987,International Affairs,6874 -73,dfcA1c84AdB61Ac,Mccall-Holmes,http://www.dean.com/,Benin,Object-based value-added database,2009,Legal Services,696 -74,208044AC2fe52F3,Massey 
LLC,https://frazier.biz/,Suriname,Configurable zero administration Graphical User Interface,1986,Accounting,5004 -75,f3C365f0c1A0623,Hicks LLC,http://alvarez.biz/,Pakistan,Quality-focused client-server Graphical User Interface,1970,Computer Software / Engineering,8480 -76,ec5Bdd3CBAfaB93,"Cole, Russell and Avery",http://www.blankenship.com/,Mongolia,De-engineered fault-tolerant challenge,2000,Law Enforcement,7012 -77,DDB19Be7eeB56B4,Cummings-Rojas,https://simon-pearson.com/,Svalbard & Jan Mayen Islands,User-centric modular customer loyalty,2012,Financial Services,7529 -78,dd6CA3d0bc3cAfc,"Beasley, Greene and Mahoney",http://www.petersen-lawrence.com/,Togo,Extended content-based methodology,1976,Religious Institutions,869 -79,A0B9d56e61070e3,"Beasley, Sims and Allison",http://burke.info/,Latvia,Secured zero tolerance hub,1972,Facilities Services,6182 -80,cBa7EFe5D05Adaf,Crawford-Rivera,https://black-ramirez.org/,Cuba,Persevering exuding budgetary management,1999,Online Publishing,7805 -81,Ea3f6D52Ec73563,Montes-Hensley,https://krueger.org/,Liechtenstein,Multi-tiered secondary productivity,2009,Printing,8433 -82,bC0CEd48A8000E0,Velazquez-Odom,https://stokes.com/,Djibouti,Streamlined 6thgeneration function,2002,Alternative Dispute Resolution,4044 -83,c89b9b59BC4baa1,Eaton-Morales,https://www.reeves-graham.com/,Micronesia,Customer-focused explicit frame,1990,Capital Markets / Hedge Fund / Private Equity,7013 -84,FEC51bce8421a7b,"Roberson, Pennington and Palmer",http://www.keith-fisher.com/,Cameroon,Adaptive bi-directional hierarchy,1993,Telecommunications,5571 -85,e0E8e27eAc9CAd5,"George, Russo and Guerra",https://drake.com/,Sweden,Centralized non-volatile capability,1989,Military Industry,2880 -86,B97a6CF9bf5983C,Davila Inc,https://mcconnell.info/,Cocos (Keeling) Islands,Profit-focused dedicated frame,2017,Consumer Electronics,2215 -87,a0a6f9b3DbcBEb5,Mays-Preston,http://www.browning-key.com/,Mali,User-centric heuristic focus group,2006,Military Industry,5786 -88,8cC1bDa330a5871,Pineda-Morton,https://www.carr.com/,United States Virgin Islands,Grass-roots methodical info-mediaries,1991,Printing,6168 -89,ED889CB2FE9cbd3,Huang and Sons,https://www.bolton.com/,Eritrea,Re-contextualized dynamic hierarchy,1981,Semiconductors,7484 -90,F4Dc1417BC6cb8f,Gilbert-Simon,https://www.bradford.biz/,Burundi,Grass-roots radical parallelism,1973,Newspapers / Journalism,1927 -91,7ABc3c7ecA03B34,Sampson-Griffith,http://hendricks.org/,Benin,Multi-layered composite paradigm,1972,Textiles,3881 -92,4e0719FBE38e0aB,Miles-Dominguez,http://www.turner.com/,Gibraltar,Organized empowering forecast,1996,Civic / Social Organization,897 -93,dEbDAAeDfaed00A,Rowe and Sons,https://www.simpson.org/,El Salvador,Balanced multimedia knowledgebase,1978,Facilities Services,8172 -94,61BDeCfeFD0cEF5,"Valenzuela, Holmes and Rowland",https://www.dorsey.net/,Taiwan,Persistent tertiary focus group,1999,Transportation,1483 -95,4e91eD25f486110,"Best, Wade and Shepard",https://zimmerman.com/,Zimbabwe,Innovative background definition,1991,Gambling / Casinos,4873 -96,0a0bfFbBbB8eC7c,Holmes Group,https://mcdowell.org/,Ethiopia,Right-sized zero tolerance focus group,1975,Photography,2988 -97,BA6Cd9Dae2Efd62,Good Ltd,http://duffy.com/,Anguilla,Reverse-engineered composite moratorium,1971,Consumer Services,4292 -98,E7df80C60Abd7f9,Clements-Espinoza,http://www.flowers.net/,Falkland Islands (Malvinas),Progressive modular hub,1991,Broadcast Media,236 -99,AFc285dbE2fEd24,Mendez Inc,https://www.burke.net/,Kyrgyz Republic,User-friendly exuding 
migration,1993,Education Management,339 -100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785 +Index,Organization Id,Name,Website,Country,Description,Founded,Industry,Number of employees +1,FAB0d41d5b5d22c,Ferrell LLC,https://price.net/,Papua New Guinea,Horizontal empowering knowledgebase,1990,Plastics,3498 +2,6A7EdDEA9FaDC52,"Mckinney, Riley and Day",http://www.hall-buchanan.info/,Finland,User-centric system-worthy leverage,2015,Glass / Ceramics / Concrete,4952 +3,0bFED1ADAE4bcC1,Hester Ltd,http://sullivan-reed.com/,China,Switchable scalable moratorium,1971,Public Safety,5287 +4,2bFC1Be8a4ce42f,Holder-Sellers,https://becker.com/,Turkmenistan,De-engineered systemic artificial intelligence,2004,Automotive,921 +5,9eE8A6a4Eb96C24,Mayer Group,http://www.brewer.com/,Mauritius,Synchronized needs-based challenge,1991,Transportation,7870 +6,cC757116fe1C085,Henry-Thompson,http://morse.net/,Bahamas,Face-to-face well-modulated customer loyalty,1992,Primary / Secondary Education,4914 +7,219233e8aFF1BC3,Hansen-Everett,https://www.kidd.org/,Pakistan,Seamless disintermediate collaboration,2018,Publishing Industry,7832 +8,ccc93DCF81a31CD,Mcintosh-Mora,https://www.brooks.com/,Heard Island and McDonald Islands,Centralized attitude-oriented capability,1970,Import / Export,4389 +9,0B4F93aA06ED03e,Carr Inc,http://ross.com/,Kuwait,Distributed impactful customer loyalty,1996,Plastics,8167 +10,738b5aDe6B1C6A5,Gaines Inc,http://sandoval-hooper.com/,Uzbekistan,Multi-lateral scalable protocol,1997,Outsourcing / Offshoring,9698 +11,AE61b8Ffebbc476,Kidd Group,http://www.lyons.com/,Bouvet Island (Bouvetoya),Proactive foreground paradigm,2001,Primary / Secondary Education,7473 +12,eb3B7D06cCdD609,Crane-Clarke,https://www.sandoval.com/,Denmark,Front-line clear-thinking encryption,2014,Food / Beverages,9011 +13,8D0c29189C9798B,"Keller, Campos and Black",https://www.garner.info/,Liberia,Ameliorated directional emulation,2020,Museums / Institutions,2862 +14,D2c91cc03CA394c,Glover-Pope,http://www.silva.biz/,United Arab Emirates,Persevering contextually-based approach,2013,Medical Practice,9079 +15,C8AC1eaf9C036F4,Pacheco-Spears,https://aguilar.com/,Sweden,Secured logistical synergy,1984,Maritime,769 +16,b5D10A14f7a8AfE,Hodge-Ayers,http://www.archer-elliott.com/,Honduras,Future-proofed radical implementation,1990,Facilities Services,8508 +17,68139b5C4De03B4,"Bowers, Guerra and Krause",http://www.carrillo-nicholson.com/,Uganda,De-engineered transitional strategy,1972,Primary / Secondary Education,6986 +18,5c2EffEfdba2BdF,Mckenzie-Melton,http://montoya-thompson.com/,Hong Kong,Reverse-engineered heuristic alliance,1998,Investment Management / Hedge Fund / Private Equity,4589 +19,ba179F19F7925f5,Branch-Mann,http://www.lozano.com/,Botswana,Adaptive intangible frame,1999,Architecture / Planning,7961 +20,c1Ce9B350BAc66b,Weiss and Sons,https://barrett.com/,Korea,Sharable optimal functionalities,2011,Plastics,5984 +21,8de40AC4e6EaCa4,"Velez, Payne and Coffey",http://burton.com/,Luxembourg,Mandatory coherent synergy,1986,Wholesale,5010 +22,Aad86a4F0385F2d,Harrell LLC,http://www.frey-rosario.com/,Guadeloupe,Reverse-engineered mission-critical moratorium,2018,Construction,2185 +23,22aC3FFd64fD703,"Eaton, Reynolds and Vargas",http://www.freeman.biz/,Monaco,Self-enabling multi-tasking process improvement,2014,Luxury Goods / Jewelry,8987 +24,5Ec4C272bCf085c,Robbins-Cummings,http://donaldson-wilkins.com/,Belgium,Organic non-volatile hierarchy,1991,Pharmaceuticals,5038 
+25,5fDBeA8BB91a000,Jenkins Inc,http://www.kirk.biz/,South Africa,Front-line systematic help-desk,2002,Insurance,1215 +26,dFfD6a6F9AC2d9C,"Greene, Benjamin and Novak",http://www.kent.net/,Romania,Centralized leadingedge moratorium,2012,Museums / Institutions,4941 +27,4B217cC5a0674C5,"Dickson, Richmond and Clay",http://everett.com/,Czech Republic,Team-oriented tangible complexity,1980,Real Estate / Mortgage,3122 +28,88b1f1cDcf59a37,Prince-David,http://thompson.com/,Christmas Island,Virtual holistic methodology,1970,Banking / Mortgage,1046 +29,f9F7bBCAEeC360F,Ayala LLC,http://www.zhang.com/,Philippines,Open-source zero administration hierarchy,2021,Legal Services,7664 +30,7Cb3AeFcE4Ba31e,Rivas Group,https://hebert.org/,Australia,Open-architected well-modulated capacity,1998,Logistics / Procurement,4155 +31,ccBcC32adcbc530,"Sloan, Mays and Whitehead",http://lawson.com/,Chad,Face-to-face high-level conglomeration,1997,Civil Engineering,365 +32,f5afd686b3d05F5,"Durham, Allen and Barnes",http://chan-stafford.org/,Zimbabwe,Synergistic web-enabled framework,1993,Mechanical or Industrial Engineering,6135 +33,38C6cfC5074Fa5e,Fritz-Franklin,http://www.lambert.com/,Nepal,Automated 4thgeneration website,1972,Hospitality,4516 +34,5Cd7efccCcba38f,Burch-Ewing,http://cline.net/,Taiwan,User-centric 4thgeneration system engine,1981,Venture Capital / VC,7443 +35,9E6Acb51e3F9d6F,"Glass, Barrera and Turner",https://dunlap.com/,Kyrgyz Republic,Multi-channeled 3rdgeneration open system,2020,Utilities,2610 +36,4D4d7E18321eaeC,Pineda-Cox,http://aguilar.org/,Bolivia,Fundamental asynchronous capability,2010,Human Resources / HR,1312 +37,485f5d06B938F2b,"Baker, Mccann and Macdonald",http://www.anderson-barker.com/,Kenya,Cross-group user-facing focus group,2013,Legislative Office,1638 +38,19E3a5Bf6dBDc4F,Cuevas-Moss,https://dodson-castaneda.net/,Guatemala,Extended human-resource intranet,1994,Music,9995 +39,6883A965c7b68F7,Hahn PLC,http://newman.com/,Belarus,Organic logistical leverage,2012,Electrical / Electronic Manufacturing,3715 +40,AC5B7AA74Aa4A2E,"Valentine, Ferguson and Kramer",http://stuart.net/,Jersey,Centralized secondary time-frame,1997,Non - Profit / Volunteering,3585 +41,decab0D5027CA6a,Arroyo Inc,https://www.turner.com/,Grenada,Managed demand-driven website,2006,Writing / Editing,9067 +42,dF084FbBb613eea,Walls LLC,http://www.reese-vasquez.biz/,Cape Verde,Self-enabling fresh-thinking installation,1989,Investment Management / Hedge Fund / Private Equity,1678 +43,A2D89Ab9bCcAd4e,"Mitchell, Warren and Schneider",https://fox.biz/,Trinidad and Tobago,Enhanced intangible time-frame,2021,Capital Markets / Hedge Fund / Private Equity,3816 +44,77aDc905434a49f,Prince PLC,https://www.watts.com/,Sweden,Profit-focused coherent installation,2016,Individual / Family Services,7645 +45,235fdEFE2cfDa5F,Brock-Blackwell,http://www.small.com/,Benin,Secured foreground emulation,1986,Online Publishing,7034 +46,1eD64cFe986BBbE,Walton-Barnett,https://ashley-schaefer.com/,Western Sahara,Right-sized clear-thinking flexibility,2001,Luxury Goods / Jewelry,1746 +47,CbBbFcdd0eaE2cF,Bartlett-Arroyo,https://cruz.com/,Northern Mariana Islands,Realigned didactic function,1976,Civic / Social Organization,3987 +48,49aECbDaE6aBD53,"Wallace, Madden and Morris",http://www.blevins-fernandez.biz/,Germany,Persistent real-time customer loyalty,2016,Pharmaceuticals,9443 +49,7b3fe6e7E72bFa4,Berg-Sparks,https://cisneros-love.com/,Canada,Stand-alone static implementation,1974,Arts / Crafts,2073 +50,c6DedA82A8aef7E,Gonzales 
Ltd,http://bird.com/,Tonga,Managed human-resource policy,1988,Consumer Goods,9069 +51,7D9FBF85cdC3871,Lawson and Sons,https://www.wong.com/,French Southern Territories,Compatible analyzing intranet,2021,Arts / Crafts,3527 +52,7dd18Fb7cB07b65,"Mcguire, Mcconnell and Olsen",https://melton-briggs.com/,Korea,Profound client-server frame,1988,Printing,8445 +53,EF5B55FadccB8Fe,Charles-Phillips,https://bowman.com/,Cote d'Ivoire,Monitored client-server implementation,2012,Mental Health Care,3450 +54,f8D4B99e11fAF5D,Odom Ltd,https://www.humphrey-hess.com/,Cote d'Ivoire,Advanced static process improvement,2012,Management Consulting,1825 +55,e24D21BFd3bF1E5,Richard PLC,https://holden-coleman.net/,Mayotte,Object-based optimizing model,1971,Broadcast Media,4942 +56,B9BdfEB6D3Ca44E,Sampson Ltd,https://blevins.com/,Cayman Islands,Intuitive local adapter,2005,Farming,1418 +57,2a74D6f3D3B268e,"Cherry, Le and Callahan",https://waller-delacruz.biz/,Nigeria,Universal human-resource collaboration,2017,Entertainment / Movie Production,7202 +58,Bf3F3f62c8aBC33,Cherry PLC,https://www.avila.info/,Marshall Islands,Persistent tertiary website,1980,Plastics,8245 +59,aeBe26B80a7a23c,Melton-Nichols,https://kennedy.com/,Palau,User-friendly clear-thinking productivity,2021,Legislative Office,8741 +60,aAeb29ad43886C6,Potter-Walsh,http://thomas-french.org/,Turkey,Optional non-volatile open system,2008,Human Resources / HR,6923 +61,bD1bc6bB6d1FeD3,Freeman-Chen,https://mathis.com/,Timor-Leste,Phased next generation adapter,1973,International Trade / Development,346 +62,EB9f456e8b7022a,Soto Group,https://norris.info/,Vietnam,Enterprise-wide executive installation,1988,Business Supplies / Equipment,9097 +63,Dfef38C51D8DAe3,"Poole, Cruz and Whitney",https://reed.info/,Reunion,Balanced analyzing groupware,1978,Marketing / Advertising / Sales,2992 +64,055ffEfB2Dd95B0,Riley Ltd,http://wiley.com/,Brazil,Optional exuding superstructure,1986,Textiles,9315 +65,cBfe4dbAE1699da,"Erickson, Andrews and Bailey",https://www.hobbs-grant.com/,Eritrea,Vision-oriented secondary project,2014,Consumer Electronics,7829 +66,fdFbecbadcdCdf1,"Wilkinson, Charles and Arroyo",http://hunter-mcfarland.com/,United States Virgin Islands,Assimilated 24/7 archive,1996,Building Materials,602 +67,5DCb8A5a5ca03c0,Floyd Ltd,http://www.whitney.com/,Falkland Islands (Malvinas),Function-based fault-tolerant concept,2017,Public Relations / PR,2911 +68,ce57DCbcFD6d618,Newman-Galloway,https://www.scott.com/,Luxembourg,Enhanced foreground collaboration,1987,Information Technology / IT,3934 +69,5aaD187dc929371,Frazier-Butler,https://www.daugherty-farley.info/,Northern Mariana Islands,Persistent interactive circuit,1972,Outsourcing / Offshoring,5130 +70,902D7Ac8b6d476b,Newton Inc,https://www.richmond-manning.info/,Netherlands Antilles,Fundamental stable info-mediaries,1976,Military Industry,563 +71,32BB9Ff4d939788,Duffy-Levy,https://www.potter.com/,Guernsey,Diverse exuding installation,1982,Wireless,6146 +72,adcB0afbE58bAe3,Wagner LLC,https://decker-esparza.com/,Uruguay,Reactive attitude-oriented toolset,1987,International Affairs,6874 +73,dfcA1c84AdB61Ac,Mccall-Holmes,http://www.dean.com/,Benin,Object-based value-added database,2009,Legal Services,696 +74,208044AC2fe52F3,Massey LLC,https://frazier.biz/,Suriname,Configurable zero administration Graphical User Interface,1986,Accounting,5004 +75,f3C365f0c1A0623,Hicks LLC,http://alvarez.biz/,Pakistan,Quality-focused client-server Graphical User Interface,1970,Computer Software / Engineering,8480 +76,ec5Bdd3CBAfaB93,"Cole, 
Russell and Avery",http://www.blankenship.com/,Mongolia,De-engineered fault-tolerant challenge,2000,Law Enforcement,7012 +77,DDB19Be7eeB56B4,Cummings-Rojas,https://simon-pearson.com/,Svalbard & Jan Mayen Islands,User-centric modular customer loyalty,2012,Financial Services,7529 +78,dd6CA3d0bc3cAfc,"Beasley, Greene and Mahoney",http://www.petersen-lawrence.com/,Togo,Extended content-based methodology,1976,Religious Institutions,869 +79,A0B9d56e61070e3,"Beasley, Sims and Allison",http://burke.info/,Latvia,Secured zero tolerance hub,1972,Facilities Services,6182 +80,cBa7EFe5D05Adaf,Crawford-Rivera,https://black-ramirez.org/,Cuba,Persevering exuding budgetary management,1999,Online Publishing,7805 +81,Ea3f6D52Ec73563,Montes-Hensley,https://krueger.org/,Liechtenstein,Multi-tiered secondary productivity,2009,Printing,8433 +82,bC0CEd48A8000E0,Velazquez-Odom,https://stokes.com/,Djibouti,Streamlined 6thgeneration function,2002,Alternative Dispute Resolution,4044 +83,c89b9b59BC4baa1,Eaton-Morales,https://www.reeves-graham.com/,Micronesia,Customer-focused explicit frame,1990,Capital Markets / Hedge Fund / Private Equity,7013 +84,FEC51bce8421a7b,"Roberson, Pennington and Palmer",http://www.keith-fisher.com/,Cameroon,Adaptive bi-directional hierarchy,1993,Telecommunications,5571 +85,e0E8e27eAc9CAd5,"George, Russo and Guerra",https://drake.com/,Sweden,Centralized non-volatile capability,1989,Military Industry,2880 +86,B97a6CF9bf5983C,Davila Inc,https://mcconnell.info/,Cocos (Keeling) Islands,Profit-focused dedicated frame,2017,Consumer Electronics,2215 +87,a0a6f9b3DbcBEb5,Mays-Preston,http://www.browning-key.com/,Mali,User-centric heuristic focus group,2006,Military Industry,5786 +88,8cC1bDa330a5871,Pineda-Morton,https://www.carr.com/,United States Virgin Islands,Grass-roots methodical info-mediaries,1991,Printing,6168 +89,ED889CB2FE9cbd3,Huang and Sons,https://www.bolton.com/,Eritrea,Re-contextualized dynamic hierarchy,1981,Semiconductors,7484 +90,F4Dc1417BC6cb8f,Gilbert-Simon,https://www.bradford.biz/,Burundi,Grass-roots radical parallelism,1973,Newspapers / Journalism,1927 +91,7ABc3c7ecA03B34,Sampson-Griffith,http://hendricks.org/,Benin,Multi-layered composite paradigm,1972,Textiles,3881 +92,4e0719FBE38e0aB,Miles-Dominguez,http://www.turner.com/,Gibraltar,Organized empowering forecast,1996,Civic / Social Organization,897 +93,dEbDAAeDfaed00A,Rowe and Sons,https://www.simpson.org/,El Salvador,Balanced multimedia knowledgebase,1978,Facilities Services,8172 +94,61BDeCfeFD0cEF5,"Valenzuela, Holmes and Rowland",https://www.dorsey.net/,Taiwan,Persistent tertiary focus group,1999,Transportation,1483 +95,4e91eD25f486110,"Best, Wade and Shepard",https://zimmerman.com/,Zimbabwe,Innovative background definition,1991,Gambling / Casinos,4873 +96,0a0bfFbBbB8eC7c,Holmes Group,https://mcdowell.org/,Ethiopia,Right-sized zero tolerance focus group,1975,Photography,2988 +97,BA6Cd9Dae2Efd62,Good Ltd,http://duffy.com/,Anguilla,Reverse-engineered composite moratorium,1971,Consumer Services,4292 +98,E7df80C60Abd7f9,Clements-Espinoza,http://www.flowers.net/,Falkland Islands (Malvinas),Progressive modular hub,1991,Broadcast Media,236 +99,AFc285dbE2fEd24,Mendez Inc,https://www.burke.net/,Kyrgyz Republic,User-friendly exuding migration,1993,Education Management,339 +100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785 diff --git a/applications/ColossalQA/data/tests/test.html b/applications/ColossalQA/data/tests/test.html index 5ad21421d..6152ffe29 100644 --- 
a/applications/ColossalQA/data/tests/test.html +++ b/applications/ColossalQA/data/tests/test.html
@@ -1,7 +1,7 @@ … @@ -16,18 +16,18 @@ … @@ -36,11 +36,11 @@
[test.html hunk bodies not recoverable: the HTML markup of this test fixture (a saved copy of "logging — Logging facility for Python — Python 3.11.5 documentation") was stripped when this patch was extracted, leaving only whitespace and +/- markers. The change appears to be line-for-line markup adjustments plus adding a trailing newline at end of file (the old version ends with "\ No newline at end of file").]
                  diff --git a/applications/ColossalQA/data/tests/test.md b/applications/ColossalQA/data/tests/test.md index 20d3c612f..b6e94792c 100644 --- a/applications/ColossalQA/data/tests/test.md +++ b/applications/ColossalQA/data/tests/test.md @@ -34,9 +34,9 @@ python api_server.py --host localhost --port $PORT_NUMBER --model $PATH_TO_MODEL ### Collect your data For ChatGPT based Agent we support document retrieval and simple sql search. -If you want to run the demo locally, we provided document retrieval based conversation system built upon langchain. It accept a wide range of documents. +If you want to run the demo locally, we provided document retrieval based conversation system built upon langchain. It accept a wide range of documents. -Read comments under ./colossalqa/data_loader for more detail +Read comments under ./colossalqa/data_loader for more detail ### Serving Currently use vllm will replace with colossal inference when ready. Please refer class VllmLLM. diff --git a/applications/ColossalQA/data/tests/test.txt b/applications/ColossalQA/data/tests/test.txt index ff5bf2dc7..e608c772e 100644 --- a/applications/ColossalQA/data/tests/test.txt +++ b/applications/ColossalQA/data/tests/test.txt @@ -1,38 +1,38 @@ -Your Name -Lorem ipsum dolor sit amet, consectetuer adipiscing elit - 123 Your Street -Your City, ST 12345 -(123) 456-7890 -no_reply@example.com - EXPERIENCE -Company, Location — Job Title -MONTH 20XX - PRESENT -Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. -Company, Location — Job Title -MONTH 20XX - MONTH 20XX -Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. -Company, Location — Job Title -MONTH 20XX - MONTH 20XX -Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. -EDUCATION -School Name, Location — Degree -MONTH 20XX - MONTH 20XX -Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore. -School Name, Location — Degree -MONTH 20XX - MONTH 20XX -Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam. -PROJECTS -Project Name — Detail -Lorem ipsum dolor sit amet, consectetuer adipiscing elit. - SKILLS -* Lorem ipsum dolor sit amet. -* Consectetuer adipiscing elit. -* Sed diam nonummy nibh euismod tincidunt. -* L​​​‌​aoreet dolore magna aliquam erat volutpat. -AWARDS -Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy -Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. -Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy -Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. -LANGUAGES -Lorem ipsum, Dolor sit amet, Consectetuer \ No newline at end of file +Your Name +Lorem ipsum dolor sit amet, consectetuer adipiscing elit + 123 Your Street +Your City, ST 12345 +(123) 456-7890 +no_reply@example.com + EXPERIENCE +Company, Location — Job Title +MONTH 20XX - PRESENT +Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. +Company, Location — Job Title +MONTH 20XX - MONTH 20XX +Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. +Company, Location — Job Title +MONTH 20XX - MONTH 20XX +Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh. +EDUCATION +School Name, Location — Degree +MONTH 20XX - MONTH 20XX +Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore. 
+School Name, Location — Degree +MONTH 20XX - MONTH 20XX +Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam. +PROJECTS +Project Name — Detail +Lorem ipsum dolor sit amet, consectetuer adipiscing elit. + SKILLS +* Lorem ipsum dolor sit amet. +* Consectetuer adipiscing elit. +* Sed diam nonummy nibh euismod tincidunt. +* L​​​‌​aoreet dolore magna aliquam erat volutpat. +AWARDS +Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy +Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. +Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy +Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. +LANGUAGES +Lorem ipsum, Dolor sit amet, Consectetuer diff --git a/applications/ColossalQA/examples/retrieval_conversation_universal.py b/applications/ColossalQA/examples/retrieval_conversation_universal.py index 5d13a63c3..8999fbabd 100644 --- a/applications/ColossalQA/examples/retrieval_conversation_universal.py +++ b/applications/ColossalQA/examples/retrieval_conversation_universal.py @@ -1,22 +1,27 @@ import argparse + from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation -if __name__ == '__main__': +if __name__ == "__main__": # Parse arguments parser = argparse.ArgumentParser() - parser.add_argument('--en_model_path', type=str, default=None) - parser.add_argument('--zh_model_path', type=str, default=None) - parser.add_argument('--zh_model_name', type=str, default=None) - parser.add_argument('--en_model_name', type=str, default=None) - parser.add_argument('--sql_file_path', type=str, default=None, help='path to the a empty folder for storing sql files for indexing') + parser.add_argument("--en_model_path", type=str, default=None) + parser.add_argument("--zh_model_path", type=str, default=None) + parser.add_argument("--zh_model_name", type=str, default=None) + parser.add_argument("--en_model_name", type=str, default=None) + parser.add_argument( + "--sql_file_path", type=str, default=None, help="path to the a empty folder for storing sql files for indexing" + ) args = parser.parse_args() - + # Will ask for documents path in running time - session = UniversalRetrievalConversation(files_en=None, - files_zh=None, - zh_model_path=args.zh_model_path, en_model_path=args.en_model_path, - zh_model_name=args.zh_model_name, en_model_name=args.en_model_name, - sql_file_path=args.sql_file_path - ) + session = UniversalRetrievalConversation( + files_en=None, + files_zh=None, + zh_model_path=args.zh_model_path, + en_model_path=args.en_model_path, + zh_model_name=args.zh_model_name, + en_model_name=args.en_model_name, + sql_file_path=args.sql_file_path, + ) session.start_test_session() - \ No newline at end of file diff --git a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py index 526328dda..47897b538 100644 --- a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py +++ b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py @@ -5,13 +5,7 @@ from colossalqa.chain.retrieval_qa.base import RetrievalQA from colossalqa.data_loader.document_loader import DocumentLoader from colossalqa.memory import ConversationBufferWithSummary from colossalqa.mylogging import get_logger -from colossalqa.prompt.prompt import ( - PROMPT_DISAMBIGUATE_ZH, - PROMPT_RETRIEVAL_QA_ZH, - SUMMARY_PROMPT_ZH, - ZH_RETRIEVAL_QA_REJECTION_ANSWER, - ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, -) +from colossalqa.prompt.prompt import 
ZH_RETRIEVAL_QA_REJECTION_ANSWER, ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS from colossalqa.retriever import CustomRetriever from langchain import LLMChain from langchain.embeddings import HuggingFaceEmbeddings @@ -116,13 +110,13 @@ class RAG_ChatBot: def split_docs(self, documents): doc_splits = self.text_splitter.split_documents(documents) return doc_splits - + def clear_docs(self, **kwargs): self.documents = [] self.docs_names = [] self.info_retriever.clear_documents() self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever) - + def reset_config(self, rag_config): self.rag_config = rag_config self.set_embed_model(**self.rag_config["embed"]) diff --git a/applications/ColossalQA/examples/webui_demo/README.md b/applications/ColossalQA/examples/webui_demo/README.md index 3e1a8adf5..61071a1c8 100644 --- a/applications/ColossalQA/examples/webui_demo/README.md +++ b/applications/ColossalQA/examples/webui_demo/README.md @@ -115,4 +115,4 @@ python webui.py --http_host "your-backend-api-host" --http_port "your-backend-ap After launching the script, you can upload files and engage with the chatbot through your web browser. -![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/new_ui.png) \ No newline at end of file +![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/new_ui.png) diff --git a/applications/ColossalQA/examples/webui_demo/config.py b/applications/ColossalQA/examples/webui_demo/config.py index ef90fab62..5bdc08a09 100644 --- a/applications/ColossalQA/examples/webui_demo/config.py +++ b/applications/ColossalQA/examples/webui_demo/config.py @@ -1,58 +1,30 @@ -from colossalqa.prompt.prompt import ( - PROMPT_DISAMBIGUATE_ZH, - PROMPT_RETRIEVAL_QA_ZH, - SUMMARY_PROMPT_ZH, - ZH_RETRIEVAL_QA_REJECTION_ANSWER, - ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS, -) +from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH from colossalqa.text_splitter import ChineseTextSplitter ALL_CONFIG = { "embed": { "embed_name": "m3e", # embedding model name "embed_model_name_or_path": "moka-ai/m3e-base", # path to embedding model, could be a local path or a huggingface path - "embed_model_device": { - "device": "cpu" - } + "embed_model_device": {"device": "cpu"}, }, "model": { "mode": "api", # "local" for loading models, "api" for using model api "model_name": "chatgpt_api", # local model name, "chatgpt_api" or "pangu_api" - "model_path": "", # path to the model, could be a local path or a huggingface path. don't need if using an api - "device": { - "device": "cuda" - } - }, - "splitter": { - "name": ChineseTextSplitter - }, - "retrieval": { - "retri_top_k": 3, - "retri_kb_file_path": "./", # path to store database files - "verbose": True + "model_path": "", # path to the model, could be a local path or a huggingface path. 
don't need if using an api + "device": {"device": "cuda"}, }, + "splitter": {"name": ChineseTextSplitter}, + "retrieval": {"retri_top_k": 3, "retri_kb_file_path": "./", "verbose": True}, # path to store database files "chain": { "mem_summary_prompt": SUMMARY_PROMPT_ZH, # summary prompt template "mem_human_prefix": "用户", "mem_ai_prefix": "Assistant", "mem_max_tokens": 2000, - "mem_llm_kwargs": { - "max_new_tokens": 50, - "temperature": 1, - "do_sample": True - }, + "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True}, "disambig_prompt": PROMPT_DISAMBIGUATE_ZH, # disambiguate prompt template - "disambig_llm_kwargs": { - "max_new_tokens": 30, - "temperature": 1, - "do_sample": True - }, - "gen_llm_kwargs": { - "max_new_tokens": 100, - "temperature": 1, - "do_sample": True - }, + "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True}, + "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True}, "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH, # generation prompt template - "verbose": True - } -} \ No newline at end of file + "verbose": True, + }, +} diff --git a/applications/ColossalQA/examples/webui_demo/server.py b/applications/ColossalQA/examples/webui_demo/server.py index 3b0f82845..1f699421d 100644 --- a/applications/ColossalQA/examples/webui_demo/server.py +++ b/applications/ColossalQA/examples/webui_demo/server.py @@ -1,27 +1,18 @@ import argparse -import os from typing import List, Union - -from colossalqa.local.llm import ColossalAPI, ColossalLLM -from colossalqa.data_loader.document_loader import DocumentLoader -from colossalqa.mylogging import get_logger -from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation -from colossalqa.retriever import CustomRetriever -from enum import Enum -from fastapi import FastAPI, Request -from langchain.embeddings import HuggingFaceEmbeddings -from langchain.text_splitter import RecursiveCharacterTextSplitter -from pydantic import BaseModel, Field -import uvicorn - import config +import uvicorn +from colossalqa.local.llm import ColossalAPI, ColossalLLM +from colossalqa.mylogging import get_logger +from fastapi import FastAPI, Request +from pydantic import BaseModel from RAG_ChatBot import RAG_ChatBot from utils import DocAction - logger = get_logger() + def parseArgs(): parser = argparse.ArgumentParser() parser.add_argument("--http_host", default="0.0.0.0") @@ -36,6 +27,7 @@ class DocUpdateReq(BaseModel): doc_files: Union[List[str], str, None] = None action: DocAction = DocAction.ADD + class GenerationTaskReq(BaseModel): user_input: str @@ -45,7 +37,7 @@ def update_docs(data: DocUpdateReq, request: Request): if data.action == "add": if isinstance(data.doc_files, str): data.doc_files = [data.doc_files] - chatbot.load_doc_from_files(files = data.doc_files) + chatbot.load_doc_from_files(files=data.doc_files) all_docs = "" for doc in chatbot.docs_names: all_docs += f"\t{doc}\n\n" @@ -79,17 +71,18 @@ if __name__ == "__main__": elif all_config["model"]["mode"] == "api": if model_name == "pangu_api": from colossalqa.local.pangu_llm import Pangu - + gen_config = { "user": "User", "max_tokens": all_config["chain"]["disambig_llm_kwargs"]["max_new_tokens"], "temperature": all_config["chain"]["disambig_llm_kwargs"]["temperature"], - "n": 1 # the number of responses generated + "n": 1, # the number of responses generated } llm = Pangu(gen_config=gen_config) llm.set_auth_config() # verify user's auth info here elif model_name == "chatgpt_api": from langchain.llms import OpenAI + 
llm = OpenAI() else: raise ValueError("Unsupported mode.") diff --git a/applications/ColossalQA/examples/webui_demo/webui.py b/applications/ColossalQA/examples/webui_demo/webui.py index cd3b5fd5d..1e3433061 100644 --- a/applications/ColossalQA/examples/webui_demo/webui.py +++ b/applications/ColossalQA/examples/webui_demo/webui.py @@ -1,24 +1,26 @@ import argparse import json import os -import requests import gradio as gr - +import requests from utils import DocAction + def parseArgs(): parser = argparse.ArgumentParser() parser.add_argument("--http_host", default="0.0.0.0") parser.add_argument("--http_port", type=int, default=13666) return parser.parse_args() + def get_response(data, url): headers = {"Content-type": "application/json"} response = requests.post(url, json=data, headers=headers) response = json.loads(response.content) return response + def add_text(history, text): history = history + [(text, None)] return history, gr.update(value=None, interactive=True) @@ -28,35 +30,28 @@ def add_file(history, files): files_string = "\n".join([os.path.basename(file.name) for file in files]) doc_files = [file.name for file in files] - data = { - "doc_files": doc_files, - "action": DocAction.ADD - } + data = {"doc_files": doc_files, "action": DocAction.ADD} response = get_response(data, update_url)["response"] history = history + [(files_string, response)] return history -def bot(history): - data = { - "user_input": history[-1][0].strip() - } + +def bot(history): + data = {"user_input": history[-1][0].strip()} response = get_response(data, gen_url) if response["error"] != "": raise gr.Error(response["error"]) - + history[-1][1] = response["response"] yield history def restart(chatbot, txt): # Reset the conversation state and clear the chat history - data = { - "doc_files": "", - "action": DocAction.CLEAR - } - response = get_response(data, update_url) - + data = {"doc_files": "", "action": DocAction.CLEAR} + get_response(data, update_url) + return gr.update(value=None), gr.update(value=None, interactive=True) @@ -97,7 +92,7 @@ with gr.Blocks(css=CSS) as demo: txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot) # Clear the original textbox - txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False) + txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False) # Click Upload Button: 1. upload files 2. send config to backend, initalize model 3. 
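For readers tracing how webui.py talks to this FastAPI backend: the generation endpoint replies with a JSON object carrying "response" and "error" keys, so a minimal standalone client looks like the sketch below. The host, port, and route name are illustrative placeholders; use whatever server.py is actually launched with.

import json

import requests


def get_backend_response(data: dict, url: str) -> dict:
    # Same pattern as webui.py's get_response: POST JSON, decode the JSON reply.
    headers = {"Content-type": "application/json"}
    resp = requests.post(url, json=data, headers=headers)
    return json.loads(resp.content)


if __name__ == "__main__":
    gen_url = "http://localhost:13666/generate"  # placeholder host/port/route
    reply = get_backend_response({"user_input": "Which company runs business in hotel industry?"}, gen_url)
    print(reply.get("error"), reply.get("response"))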
get response "conversation_ready" = True/False file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False) diff --git a/applications/ColossalQA/pytest.ini b/applications/ColossalQA/pytest.ini index 9e84349f2..8bd925076 100644 --- a/applications/ColossalQA/pytest.ini +++ b/applications/ColossalQA/pytest.ini @@ -1,4 +1,4 @@ [pytest] markers = dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) - largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) \ No newline at end of file + largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) diff --git a/applications/ColossalQA/tests/test_document_loader.py b/applications/ColossalQA/tests/test_document_loader.py index 163b0d2cc..96c74f128 100644 --- a/applications/ColossalQA/tests/test_document_loader.py +++ b/applications/ColossalQA/tests/test_document_loader.py @@ -1,21 +1,21 @@ import os + from colossalqa.data_loader.document_loader import DocumentLoader def test_add_document(): - PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH') - files = [[PATH, 'all data']] + PATH = os.environ.get("TEST_DOCUMENT_LOADER_DATA_PATH") + files = [[PATH, "all data"]] document_loader = DocumentLoader(files) documents = document_loader.all_data all_files = [] for doc in documents: - assert isinstance(doc.page_content, str)==True - if doc.metadata['source'] not in all_files: - all_files.append(doc.metadata['source']) + assert isinstance(doc.page_content, str) == True + if doc.metadata["source"] not in all_files: + all_files.append(doc.metadata["source"]) print(all_files) assert len(all_files) == 6 -if __name__=='__main__': +if __name__ == "__main__": test_add_document() - diff --git a/applications/ColossalQA/tests/test_retrieval_qa.py b/applications/ColossalQA/tests/test_retrieval_qa.py index 76867b562..8ba84cc73 100644 --- a/applications/ColossalQA/tests/test_retrieval_qa.py +++ b/applications/ColossalQA/tests/test_retrieval_qa.py @@ -4,56 +4,44 @@ from colossalqa.retrieval_conversation_universal import UniversalRetrievalConver def test_en_retrievalQA(): - data_path_en = os.environ.get('TEST_DATA_PATH_EN') - data_path_zh = os.environ.get('TEST_DATA_PATH_ZH') - en_model_path = os.environ.get('EN_MODEL_PATH') - zh_model_path = os.environ.get('ZH_MODEL_PATH') - zh_model_name = os.environ.get('ZH_MODEL_NAME') - en_model_name = os.environ.get('EN_MODEL_NAME') - sql_file_path = os.environ.get('SQL_FILE_PATH') - qa_session = UniversalRetrievalConversation(files_en=[{ - 'data_path': data_path_en, - 'name': 'company information', - 'separator': '\n' - }], - files_zh=[{ - 'data_path': data_path_zh, - 'name': 'company information', - 'separator': '\n' - }], - zh_model_path=zh_model_path, - en_model_path=en_model_path, - zh_model_name=zh_model_name, - en_model_name=en_model_name, - sql_file_path=sql_file_path) - ans = qa_session.run("which company runs business in hotel industry?", which_language='en') + data_path_en = os.environ.get("TEST_DATA_PATH_EN") + data_path_zh = os.environ.get("TEST_DATA_PATH_ZH") + en_model_path = os.environ.get("EN_MODEL_PATH") + zh_model_path = os.environ.get("ZH_MODEL_PATH") + zh_model_name = os.environ.get("ZH_MODEL_NAME") + en_model_name = os.environ.get("EN_MODEL_NAME") + sql_file_path = os.environ.get("SQL_FILE_PATH") + qa_session = UniversalRetrievalConversation( + files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}], + files_zh=[{"data_path": data_path_zh, "name": "company 
information", "separator": "\n"}], + zh_model_path=zh_model_path, + en_model_path=en_model_path, + zh_model_name=zh_model_name, + en_model_name=en_model_name, + sql_file_path=sql_file_path, + ) + ans = qa_session.run("which company runs business in hotel industry?", which_language="en") print(ans) def test_zh_retrievalQA(): - data_path_en = os.environ.get('TEST_DATA_PATH_EN') - data_path_zh = os.environ.get('TEST_DATA_PATH_ZH') - en_model_path = os.environ.get('EN_MODEL_PATH') - zh_model_path = os.environ.get('ZH_MODEL_PATH') - zh_model_name = os.environ.get('ZH_MODEL_NAME') - en_model_name = os.environ.get('EN_MODEL_NAME') - sql_file_path = os.environ.get('SQL_FILE_PATH') - qa_session = UniversalRetrievalConversation(files_en=[{ - 'data_path': data_path_en, - 'name': 'company information', - 'separator': '\n' - }], - files_zh=[{ - 'data_path': data_path_zh, - 'name': 'company information', - 'separator': '\n' - }], - zh_model_path=zh_model_path, - en_model_path=en_model_path, - zh_model_name=zh_model_name, - en_model_name=en_model_name, - sql_file_path=sql_file_path) - ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh') + data_path_en = os.environ.get("TEST_DATA_PATH_EN") + data_path_zh = os.environ.get("TEST_DATA_PATH_ZH") + en_model_path = os.environ.get("EN_MODEL_PATH") + zh_model_path = os.environ.get("ZH_MODEL_PATH") + zh_model_name = os.environ.get("ZH_MODEL_NAME") + en_model_name = os.environ.get("EN_MODEL_NAME") + sql_file_path = os.environ.get("SQL_FILE_PATH") + qa_session = UniversalRetrievalConversation( + files_en=[{"data_path": data_path_en, "name": "company information", "separator": "\n"}], + files_zh=[{"data_path": data_path_zh, "name": "company information", "separator": "\n"}], + zh_model_path=zh_model_path, + en_model_path=en_model_path, + zh_model_name=zh_model_name, + en_model_name=en_model_name, + sql_file_path=sql_file_path, + ) + ans = qa_session.run("哪家公司在经营酒店业务?", which_language="zh") print(ans) diff --git a/applications/ColossalQA/version.txt b/applications/ColossalQA/version.txt index 8a9ecc2ea..8acdd82b7 100644 --- a/applications/ColossalQA/version.txt +++ b/applications/ColossalQA/version.txt @@ -1 +1 @@ -0.0.1 \ No newline at end of file +0.0.1 diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 6b7f5d055..beef0561c 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,5 +1,5 @@ -from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch from . 
import accelerator +from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch try: # .version will be created by setup.py diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 5445b4a63..0aa0caa9a 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -27,7 +27,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -93,9 +93,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): Path(checkpoint_path).mkdir(parents=True, exist_ok=True) with FSDP.state_dict_type( - model.unwrap(), - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) ): state_dict = model.unwrap().state_dict() @@ -172,7 +170,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): with FSDP.state_dict_type( optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): fsdp_optim_state = FSDP.full_optim_state_dict( optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True @@ -241,7 +239,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): ) optimizer.load_state_dict(fsdp_state) - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ Save model to checkpoint but only on master process. diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index e1800f29b..2a1d4de9b 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -294,6 +294,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Helper functions for saving state dict # ====================================== + def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: """ Save state dict to checkpoint. @@ -305,7 +306,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors """ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict) - + if use_safetensors: assert is_safetensors_available(), "safetensors is not available." assert checkpoint_file_path.endswith( diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index ccf122695..43095af50 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -174,16 +174,20 @@ class ProcessGroupMesh: List[Tuple[int, ...]]: Coordinates along the axis. 
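The tree_map call added to save_state_dict above moves every tensor leaf to CPU before serialization. A tiny self-contained sketch of the same pattern, assuming tree_map is torch.utils._pytree.tree_map:

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

model = nn.Linear(4, 4)
state_dict = model.state_dict()

# Offload all tensor leaves to CPU; non-tensor leaves pass through untouched.
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

torch.save(state_dict_cpu, "linear_cpu.pt")
print({k: v.device for k, v in state_dict_cpu.items()})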
""" if isinstance(axis, int): - axis = [axis,] + axis = [ + axis, + ] assert isinstance(indices_at_axis[0], int) - indices_at_axis = [indices_at_axis,] + indices_at_axis = [ + indices_at_axis, + ] def add_index(base_coord, axis, indices_at_axis): coords_in_group = [] for idx in indices_at_axis: coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) return coords_in_group - + coords_in_group = [base_coord] for ax, indices_at_ax in zip(axis, indices_at_axis): new_coords_in_group = [] @@ -194,7 +198,10 @@ class ProcessGroupMesh: return coords_in_group def create_group_along_axis( - self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None + self, + axis: Union[int, List[int]], + indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, + backend: Optional[str] = None, ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. @@ -207,11 +214,15 @@ class ProcessGroupMesh: ProcessGroup: The process group along the given axis which the current process belongs to. """ if isinstance(axis, int): - axis = [axis,] + axis = [ + axis, + ] if indices_at_axis is not None: assert isinstance(indices_at_axis[0], int) - indices_at_axis = [indices_at_axis,] - + indices_at_axis = [ + indices_at_axis, + ] + indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis] reduced_shape = list(self._shape) # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py index b7bc94d0e..a7efb4026 100644 --- a/colossalai/inference/engine/modeling/llama.py +++ b/colossalai/inference/engine/modeling/llama.py @@ -29,13 +29,17 @@ except: try: from colossalai.kernel.triton.flash_decoding import token_flash_decoding + HAS_TRITON_FLASH_DECODING_KERNEL = True except: - print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") + print( + "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8" + ) HAS_TRITON_FLASH_DECODING_KERNEL = False - + try: from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True except: HAS_FLASH_KERNEL = False @@ -48,6 +52,7 @@ def rotate_half(x): x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] @@ -96,17 +101,22 @@ def llama_triton_context_attention( infer_state.max_len_in_batch, ) -def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1): + +def llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1 +): if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1: - token_flash_decoding(q = query_states, - o_tensor = attn_output, - infer_state = infer_state, - q_head_num = q_head_num, - head_dim = head_dim, - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]) - return - + token_flash_decoding( + q=query_states, + o_tensor=attn_output, + infer_state=infer_state, + q_head_num=q_head_num, + head_dim=head_dim, + cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + ) + return + if num_key_value_groups == 1: token_attention_fwd( query_states, @@ -459,14 +469,15 @@ class LlamaInferenceForwards: ) if HAS_LIGHTLLM_KERNEL: - attn_output = torch.empty_like(query_states) - llama_triton_token_attention(query_states = query_states, - attn_output = attn_output, - infer_state = infer_state, - num_key_value_groups = self.num_key_value_groups, - q_head_num = q_len * self.num_heads, - head_dim = self.head_dim) + llama_triton_token_attention( + query_states=query_states, + attn_output=attn_output, + infer_state=infer_state, + num_key_value_groups=self.num_key_value_groups, + q_head_num=q_len * self.num_heads, + head_dim=self.head_dim, + ) else: self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py index ca12c34ed..36339ac88 100644 --- a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False class CaiQuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): super().__init__() if bits not in [2, 4, 8]: @@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module): self.maxq = 2**self.bits - 1 self.groupsize = groupsize if groupsize != -1 else infeatures - self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) self.register_buffer( - 'qzeros', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + "qzeros", + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), + ) + self.register_buffer( + "scales", torch.zeros((math.ceil(infeatures / 
self.groupsize), outfeatures), dtype=torch.float16) + ) if row_split: self.register_buffer( - 'g_idx', - torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], - dtype=torch.int32)) + "g_idx", + torch.tensor( + [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 + ), + ) else: - self.register_buffer('g_idx', - torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + self.register_buffer( + "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) + ) if bias: - self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) else: self.bias = None @@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module): self.row_split = row_split def pack(self, linear, scales, zeros, g_idx=None): - - g_idx = g_idx.clone() if g_idx is not None else torch.tensor( - [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + g_idx = ( + g_idx.clone() + if g_idx is not None + else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + ) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module): if linear.bias is not None: self.bias = linear.bias.clone().half() - wn = 8 pbits = 32 ptype = torch.int32 unsign_type = np.uint32 @@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module): intweight = [] for idx in range(self.infeatures): intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -109,7 +116,7 @@ class CaiQuantLinear(nn.Module): raise NotImplementedError("Only 2,4,8 bits are supported.") qweight = qweight.astype(sign_type) qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() #.to("cuda") + qweight1 = qweight1.contiguous() # .to("cuda") self.qweight.data.copy_(qweight1) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) @@ -140,17 +147,20 @@ class CaiQuantLinear(nn.Module): self.q4_width = self.qweight.shape[1] if self.g_idx is not None: if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device, + ), + ): self.g_idx = None elif torch.equal( - self.g_idx, - torch.tensor([i // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device + ), + ): self.g_idx = None if self.g_idx is not None: @@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module): outshape = x.shape[:-1] + (self.outfeatures,) if HAS_GPTQ_CUDA and self.bits == 4: - if self.q4 is None: self.init_q4() @@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module): def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): - 
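For intuition about the qweight buffer shape (infeatures // 32 * bits, outfeatures) used by CaiQuantLinear, a short pure-Python sketch of packing and unpacking 4-bit values into 32-bit words; the real pack() does the analogous packing with NumPy on the transposed weight, so treat this only as an illustration of the bit-layout idea:

bits = 4
vals_per_word = 32 // bits                       # 8 nibbles per 32-bit word
intweight = list(range(16))                      # pretend quantized weights in [0, 15]

# Pack: every group of 8 values becomes one 32-bit word.
packed = [0] * (len(intweight) // vals_per_word)
for j, v in enumerate(intweight):
    packed[j // vals_per_word] |= v << (bits * (j % vals_per_word))

# Unpack and verify the round trip.
unpacked = [(packed[j // vals_per_word] >> (bits * (j % vals_per_word))) & 0xF for j in range(len(intweight))]
assert unpacked == intweight
print([hex(w) for w in packed])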
qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) @@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1 zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num for i in range(split_num): - cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] - cai_linear.qzeros[:, i * zero_split_block:(i + 1) * - zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] - cai_linear.scales[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ + :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block + ] + cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] if cai_linear.bias is not None: - cai_linear.bias[i * cai_split_out_features:(i + 1) * - cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ + tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] cai_linear.g_idx.copy_(g_idx) def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) @@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): idx_split_features = cai_linear.infeatures // split_num for i in range(split_num): - cai_linear.qweight[i * cai_split_in_features:(i + 1) * - cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * - cai_split_in_features, :] - cai_linear.qzeros[i * zero_split_block:(i + 1) * - zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.scales[i * zero_split_block:(i + 1) * - zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.g_idx[i * idx_split_features:(i + 1) * - idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * - idx_split_features] + cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ + tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : + ] + cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.g_idx[i * 
idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ + tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features + ] if cai_linear.bias is not None: cai_linear.bias.copy_(gptq_linear.bias) class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = RowCaiQuantLinear(module.bits, - module.group_size, - module.in_features // tp_size, - module.out_features, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=True) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = RowCaiQuantLinear( + module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True, + ) linear_1d.process_group = process_group split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) @@ -306,30 +310,23 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." 
process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = ColCaiQuantLinear(module.bits, - module.group_size, - module.in_features, - module.out_features // tp_size, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = ColCaiQuantLinear( + module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) linear_1d.process_group = process_group split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 3d9a23d2f..1725581d6 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -5,6 +5,7 @@ import torch try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -16,6 +17,7 @@ if HAS_TRITON: https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 """ if triton.__version__ < "2.1.0": + @triton.jit def _context_flash_attention_kernel( Q, @@ -131,29 +133,47 @@ if HAS_TRITON: m_i = m_i_new off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + (cur_batch_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return + else: # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11 @triton.jit def _context_flash_attention_kernel_2( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, - Out, - kv_group_num, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + Q, + K, + V, + sm_scale, + Alibi, + B_Start_Loc, + B_Seqlen, + Out, + kv_group_num, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - + if kv_group_num is not None: cur_kv_head = cur_head // kv_group_num @@ -166,7 +186,11 @@ if HAS_TRITON: offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) if kv_group_num is None or kv_group_num == 1: off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd @@ 
-191,8 +215,11 @@ if HAS_TRITON: for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -220,8 +247,11 @@ if HAS_TRITON: acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -229,7 +259,11 @@ if HAS_TRITON: l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @@ -249,7 +283,7 @@ if HAS_TRITON: grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) num_warps = 4 if Lk <= 64 else 8 - + if triton.__version__ < "2.1.0": tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( @@ -286,20 +320,26 @@ if HAS_TRITON: ) else: _context_flash_attention_kernel_2[grid]( - q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + alibi, + b_start_loc, + b_seq_len, o, None, - q.stride(0), - q.stride(1), + q.stride(0), + q.stride(1), q.stride(2), - k.stride(0), - k.stride(1), + k.stride(0), + k.stride(1), k.stride(2), - v.stride(0), - v.stride(1), + v.stride(0), + v.stride(1), v.stride(2), - o.stride(0), - o.stride(1), + o.stride(0), + o.stride(1), o.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, @@ -307,7 +347,7 @@ if HAS_TRITON: num_warps=num_warps, num_stages=1, ) - + return @torch.no_grad() @@ -327,7 +367,7 @@ if HAS_TRITON: tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 - + if triton.__version__ < "2.1.0": _context_flash_attention_kernel[grid]( q, @@ -337,7 +377,7 @@ if HAS_TRITON: b_start_loc, b_seq_len, tmp, - None, + None, o, q.stride(0), q.stride(1), @@ -362,32 +402,33 @@ if HAS_TRITON: ) else: kv_group_num = q.shape[1] // k.shape[1] - _context_flash_attention_kernel_2[grid]( - q, - k, - v, - sm_scale, + _context_flash_attention_kernel_2[grid]( + q, + k, + v, + sm_scale, None, - b_start_loc, + b_start_loc, b_seq_len, o, kv_group_num, - q.stride(0), - q.stride(1), + q.stride(0), + q.stride(1), q.stride(2), - k.stride(0), - k.stride(1), + k.stride(0), + k.stride(1), k.stride(2), - v.stride(0), - v.stride(1), + v.stride(0), + v.stride(1), v.stride(2), - o.stride(0), - o.stride(1), + o.stride(0), + o.stride(1), o.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=1,) - - return \ No newline at end of file + num_stages=1, + ) + + return diff --git a/colossalai/kernel/triton/flash_decoding.py 
b/colossalai/kernel/triton/flash_decoding.py index 9b7b27fa1..ac733dede 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -1,8 +1,10 @@ # adepted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py import torch + try: from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1 from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 + HAS_LIGHTLLM_KERNEL = True except: print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8") @@ -10,41 +12,36 @@ except: if HAS_LIGHTLLM_KERNEL: + def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v): BLOCK_SEQ = 256 batch_size = infer_state.batch_size max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - if getattr(infer_state, 'mid_o', None) is None: - infer_state.mid_o = torch.empty([batch_size, - q_head_num, - max_len_in_batch // BLOCK_SEQ + 1, - head_dim], - dtype=torch.float32, - device="cuda") - infer_state.mid_o_logexpsum = torch.empty([batch_size, - q_head_num, - max_len_in_batch // BLOCK_SEQ + 1], - dtype=torch.float32, - device="cuda") + if getattr(infer_state, "mid_o", None) is None: + infer_state.mid_o = torch.empty( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], + dtype=torch.float32, + device="cuda", + ) + infer_state.mid_o_logexpsum = torch.empty( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + ) mid_o = infer_state.mid_o mid_o_logexpsum = infer_state.mid_o_logexpsum - flash_decode_stage1(q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.block_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ) - flash_decode_stage2(mid_o, - mid_o_logexpsum, - infer_state.seq_len, - o_tensor.view(calcu_shape1), - BLOCK_SEQ) + flash_decode_stage1( + q.view(calcu_shape1), + cache_k, + cache_v, + infer_state.block_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + mid_o, + mid_o_logexpsum, + BLOCK_SEQ, + ) + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py index 45996c0dc..7a2c7e8fb 100644 --- a/colossalai/kernel/triton/llama_act_combine_kernel.py +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -8,6 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -26,8 +27,8 @@ if HAS_TRITON: X_GATE2, X_UP, Y, - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -41,9 +42,9 @@ if HAS_TRITON: for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) - x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) - x_up = tl.load(X_UP + cols, mask=mask, other=0.) 
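As a plain-PyTorch reference for the split-KV ("flash decoding") idea behind the two lightllm stages above — assuming stage 1 produces per-block partial outputs plus their log-sum-exps (mid_o, mid_o_logexpsum) and stage 2 combines them — a minimal single-query, single-head sketch:

import torch

torch.manual_seed(0)
head_dim, seq_len, block = 64, 1000, 256
q = torch.randn(head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)
scale = head_dim ** -0.5

# Stage 1: per KV block, a normalized partial output and its log-sum-exp.
mid_o, mid_lse = [], []
for start in range(0, seq_len, block):
    s = (k[start:start + block] @ q) * scale      # attention scores for this block
    m = s.max()
    p = torch.exp(s - m)
    mid_o.append(p @ v[start:start + block] / p.sum())
    mid_lse.append(m + torch.log(p.sum()))

# Stage 2: weight the partial outputs by a softmax over their log-sum-exps.
mid_o = torch.stack(mid_o)                        # (num_blocks, head_dim)
mid_lse = torch.stack(mid_lse)                    # (num_blocks,)
out = (torch.softmax(mid_lse, dim=0).unsqueeze(1) * mid_o).sum(0)

# Identical to ordinary softmax attention over the full sequence.
ref = torch.softmax((k @ q) * scale, dim=0) @ v
assert torch.allclose(out, ref, atol=1e-5)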
+ x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0) + x_up = tl.load(X_UP + cols, mask=mask, other=0.0) x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up # Write output @@ -58,8 +59,8 @@ if HAS_TRITON: X_GATE2_GRAD, X_UP_GRAD, Y_GRAD, - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -76,10 +77,10 @@ if HAS_TRITON: for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) - x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) - x_up = tl.load(X_UP + cols, mask=mask, other=0.) - y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.0) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.0) + x_up = tl.load(X_UP + cols, mask=mask, other=0.0) + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.0) # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) @@ -147,14 +148,9 @@ if HAS_TRITON: # restore setting ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps # enqueue kernel - _llama_act_combine_forward[(M,)](x_gate1, - x_gate2, - x_up, - y, - x_up.stride(-2), - N, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps) + _llama_act_combine_forward[(M,)]( + x_gate1, x_gate2, x_up, y, x_up.stride(-2), N, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) return y @staticmethod @@ -166,20 +162,25 @@ if HAS_TRITON: # init grad y_grad = grad_outputs[0] - x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( - x_gate2), torch.empty_like(x_up) + x_gate1_grad, x_gate2_grad, x_up_grad = ( + torch.empty_like(x_gate1), + torch.empty_like(x_gate2), + torch.empty_like(x_up), + ) # enqueue kernel - _llama_act_combine_backward[(M,)](x_gate1, - x_gate2, - x_up, - x_gate1_grad, - x_gate2_grad, - x_up_grad, - y_grad, - x_up.stride(-2), - N, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps) + _llama_act_combine_backward[(M,)]( + x_gate1, + x_gate2, + x_up, + x_gate1_grad, + x_gate2_grad, + x_up_grad, + y_grad, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) return x_gate_grad, x_up_grad, None, None diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index de2003748..d8ac278c7 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -13,10 +13,18 @@ except ImportError: print("please install triton from https://github.com/openai/triton") try: - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd - from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd - from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + from 
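Element-wise, the kernel above computes y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up, i.e. a SwiGLU-style combine x_gate1 * silu(x_gate2) * x_up. A tiny eager-mode reference:

import torch
import torch.nn.functional as F

x_gate1, x_gate2, x_up = (torch.randn(2, 8) for _ in range(3))

y_formula = x_gate1 * x_gate2 * torch.sigmoid(x_gate2) * x_up
y_silu = x_gate1 * F.silu(x_gate2) * x_up        # silu(x) = x * sigmoid(x)

assert torch.allclose(y_formula, y_silu)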
lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_bloom_token_att_fwd, + ) + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama_token_att_fwd, + ) + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama_token_att_fwd2, + ) + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama_token_softmax_fwd, + ) HAS_TRITON_TOKEN_ATTENTION = True except ImportError: @@ -205,9 +213,7 @@ class Llama2TokenAttentionForwards: if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - lightllm_llama_token_softmax_fwd( - att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch - ) + lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None lightllm_llama_token_att_fwd2( diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/llama.py b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py index 448943b12..21e5c5e78 100644 --- a/colossalai/legacy/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py @@ -8,7 +8,9 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards + from ._utils import copy_kv_to_mem_cache + try: from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama_context_attention_fwd, @@ -90,7 +92,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key # infer_state.cache_manager.past_key_values_length, infer_state.max_len_in_batch, ) - + else: Llama2TokenAttentionForwards.token_attn( query_states, diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0e368dbf9..7b8aa5380 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,5 +1,5 @@ -from .attn import AttnMaskType, ColoAttention from ._operation import all_to_all_comm +from .attn import AttnMaskType, ColoAttention from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 6f8097735..4129ec62e 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -2,13 +2,13 @@ from .api import ( compute_global_numel, customized_distributed_tensor_to_param, distribute_tensor, - init_as_dtensor, distribute_tensor_with_customization, - init_tensor_as_customization_distributed, get_device_mesh, get_global_shape, get_layout, get_sharding_spec, + init_as_dtensor, + init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, is_sharded, diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index da6ef275e..725817088 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -128,7 +128,10 @@ def distribute_tensor(tensor: torch.Tensor, 
device_mesh: DeviceMesh, sharding_sp return sharded_tensor -def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor: + +def init_as_dtensor( + tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size +) -> torch.Tensor: assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) @@ -140,6 +143,7 @@ def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec return tensor + def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: """ Convert the layout of the tensor from source_spec to target_spec. @@ -468,7 +472,6 @@ def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gat assert callable(gather_fn), "The gather_fn must be callable." assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." - # set the shard_fn and gather_fn as attributes of the distributed tensor tensor.shard_fn = shard_fn tensor.gather_fn = gather_fn diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index de08ecf3d..5ab703f09 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -190,6 +190,7 @@ def calculate_global_norm_from_list(norm_list): total_norm += norm**2.0 return math.sqrt(total_norm) + def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md index 92775bafb..f32f6c367 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md @@ -220,7 +220,7 @@ model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( ) ``` ## 使用混合并行训练 ViT -最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。 +最后就可以使用混合并行策略来训练模型了。我们先定义一个训练函数,描述训练过程。需要注意的是,如果使用了管道并行策略,需要调用`booster.execute_pipeline`来执行模型的训练,它会调用`scheduler`管理模型的前后向操作。 ```python def run_forward_backward( model: nn.Module, diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index 32b1ec803..fdae9ee01 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -119,9 +119,7 @@ def main(): if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: # run pipeline forward backward batch = iter([batch]) - outputs = booster.execute_pipeline( - batch, model, criterion, optimizer, return_loss=True - ) + outputs = booster.execute_pipeline(batch, model, criterion, optimizer, return_loss=True) else: outputs = model(**batch) loss = criterion(outputs, None) diff --git a/examples/language/data_utils.py b/examples/language/data_utils.py index ec849ef9d..6b9e8ef28 100644 --- a/examples/language/data_utils.py +++ b/examples/language/data_utils.py @@ -121,4 +121,4 @@ class RandomDataset(Dataset): "input_ids": self.input_ids[idx], "attention_mask": self.attention_mask[idx], "labels": self.input_ids[idx], - } \ No newline at end of file + } diff --git a/examples/language/llama2/finetune.py 
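A quick check of calculate_global_norm_from_list above: combining per-partition gradient norms into a global L2 norm is just the square root of the summed squares, so it matches the norm of the concatenated gradients:

import math

import torch

grads = [torch.randn(10) for _ in range(4)]       # stand-ins for per-partition flat gradients
partition_norms = [g.norm(2).item() for g in grads]

global_norm = math.sqrt(sum(n ** 2.0 for n in partition_norms))
assert abs(global_norm - torch.cat(grads).norm(2).item()) < 1e-4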
b/examples/language/llama2/finetune.py index 122186c30..69b4ebe42 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -270,9 +270,7 @@ def main(): ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline( - dataloader_iter, model, _criterion, optimizer, return_loss=True - ) + outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) loss = outputs["loss"] else: batch = next(dataloader_iter) diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index 7b5805b80..970cd5290 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -285,9 +285,7 @@ def main(): ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline( - dataloader_iter, model, _criterion, optimizer, return_loss=True - ) + outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True) loss = outputs["loss"] else: batch = next(dataloader_iter) diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py index 7a0955bb0..096e06bd2 100644 --- a/examples/language/openmoe/benchmark/utils.py +++ b/examples/language/openmoe/benchmark/utils.py @@ -50,7 +50,6 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None self.duration: float = 0.0 @@ -112,7 +111,7 @@ class PerformanceEvaluator: batch_size, seq_len = input_ids.shape self.num_samples += batch_size - self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))) + self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)) def on_fit_end(self) -> None: avg_duration = all_reduce_mean(self.timer.duration, self.world_size) @@ -122,5 +121,6 @@ class PerformanceEvaluator: if dist.get_rank() == 0: print( f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " - f"avg_throughput: {avg_throughput}") + f"avg_throughput: {avg_throughput}" + ) print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}") diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index 04df64531..50cdc63e8 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -16,17 +16,15 @@ def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") if args.model == "test": config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base") - set_openmoe_args(config, - num_experts=config.num_experts, - moe_layer_interval=config.moe_layer_interval, - enable_kernel=True) + set_openmoe_args( + config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=True + ) model = OpenMoeForCausalLM(config) else: config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}") - set_openmoe_args(config, - num_experts=config.num_experts, - moe_layer_interval=config.moe_layer_interval, - enable_kernel=False) + set_openmoe_args( + config, num_experts=config.num_experts, moe_layer_interval=config.moe_layer_interval, enable_kernel=False + ) model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config) model = model.eval().bfloat16() model = model.to(torch.cuda.current_device()) diff --git 
a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py index 20b1e780d..3e051850d 100644 --- a/examples/language/openmoe/model/convert_openmoe_ckpt.py +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py @@ -172,9 +172,9 @@ def make_state_dict(converted_params): def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): """Replaces the params in model witht the T5X converted params.""" variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) - converted = convert_t5x_to_pytorch(variables, - num_layers=config.num_hidden_layers, - moe_interval=config.moe_layer_interval) + converted = convert_t5x_to_pytorch( + variables, num_layers=config.num_hidden_layers, moe_interval=config.moe_layer_interval + ) state_dict = make_state_dict(converted) model.load_state_dict(state_dict, strict=True) @@ -203,11 +203,9 @@ def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") # Required parameters - parser.add_argument("--t5x_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the T5X checkpoint.") + parser.add_argument( + "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path to the T5X checkpoint." + ) parser.add_argument( "--config_file", default=None, @@ -215,10 +213,8 @@ if __name__ == "__main__": required=True, help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", ) - parser.add_argument("--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.") + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." 
+ ) args = parser.parse_args() convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index 82dff1920..05336bec4 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -41,9 +41,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline( - dataloader, model, _criterion, optimizer, return_loss=True - ) + outputs = booster.execute_pipeline(dataloader, model, _criterion, optimizer, return_loss=True) # Backward and optimize if is_pp_last_stage: loss = outputs["loss"] diff --git a/extensions/cpu_adam/__init__.py b/extensions/cpu_adam/__init__.py index cfd26a6a2..d5c69e902 100644 --- a/extensions/cpu_adam/__init__.py +++ b/extensions/cpu_adam/__init__.py @@ -1,5 +1,4 @@ from .cpu_adam_arm import CpuAdamArmExtension from .cpu_adam_x86 import CpuAdamX86Extension -__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] - +__all__ = ["CpuAdamArmExtension", "CpuAdamX86Extension"] diff --git a/extensions/layernorm/__init__.py b/extensions/layernorm/__init__.py index 9d1bd2d01..30e6c68ef 100644 --- a/extensions/layernorm/__init__.py +++ b/extensions/layernorm/__init__.py @@ -1,3 +1,3 @@ from .layernorm_cuda import LayerNormCudaExtension -__all__ = ["LayerNormCudaExtension"] \ No newline at end of file +__all__ = ["LayerNormCudaExtension"] diff --git a/extensions/moe/__init__.py b/extensions/moe/__init__.py index 962084d4b..3b6aa24bf 100644 --- a/extensions/moe/__init__.py +++ b/extensions/moe/__init__.py @@ -1,3 +1,3 @@ from .moe_cuda import MoeCudaExtension -__all__ = ['MoeCudaExtension'] \ No newline at end of file +__all__ = ["MoeCudaExtension"] diff --git a/extensions/optimizer/__init__.py b/extensions/optimizer/__init__.py index 9c8e87cae..6a0c8d7b8 100644 --- a/extensions/optimizer/__init__.py +++ b/extensions/optimizer/__init__.py @@ -1,3 +1,3 @@ from .fused_optimizer_cuda import FusedOptimizerCudaExtension -__all__ = ['FusedOptimizerCudaExtension'] \ No newline at end of file +__all__ = ["FusedOptimizerCudaExtension"] diff --git a/extensions/softmax/__init__.py b/extensions/softmax/__init__.py index 9bc50c6cd..8833d93e7 100644 --- a/extensions/softmax/__init__.py +++ b/extensions/softmax/__init__.py @@ -1,4 +1,4 @@ from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension -__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] \ No newline at end of file +__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"] diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 5f6789ff3..66c794a7d 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,33 +1,33 @@ import os + from . 
import custom, diffusers, timm, torchaudio, torchvision, transformers from .executor import run_fwd, run_fwd_bwd from .registry import model_zoo # We pick a subset of models for fast testing in order to reduce the total testing time COMMON_MODELS = [ - 'custom_hanging_param_model', - 'custom_nested_model', - 'custom_repeated_computed_layers', - 'custom_simple_net', - 'diffusers_clip_text_model', - 'diffusers_auto_encoder_kl', - 'diffusers_unet2d_model', - 'timm_densenet', - 'timm_resnet', - 'timm_swin_transformer', - 'torchaudio_wav2vec2_base', - 'torchaudio_conformer', - 'transformers_bert_for_masked_lm', - 'transformers_bloom_for_causal_lm', - 'transformers_falcon_for_causal_lm', - 'transformers_chatglm_for_conditional_generation', - 'transformers_llama_for_casual_lm', - 'transformers_vit_for_masked_image_modeling', - 'transformers_mistral_for_casual_lm' + "custom_hanging_param_model", + "custom_nested_model", + "custom_repeated_computed_layers", + "custom_simple_net", + "diffusers_clip_text_model", + "diffusers_auto_encoder_kl", + "diffusers_unet2d_model", + "timm_densenet", + "timm_resnet", + "timm_swin_transformer", + "torchaudio_wav2vec2_base", + "torchaudio_conformer", + "transformers_bert_for_masked_lm", + "transformers_bloom_for_causal_lm", + "transformers_falcon_for_causal_lm", + "transformers_chatglm_for_conditional_generation", + "transformers_llama_for_casual_lm", + "transformers_vit_for_masked_image_modeling", + "transformers_mistral_for_casual_lm", ] -IS_FAST_TEST = os.environ.get('FAST_TEST', '0') == '1' +IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1" -__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", 'COMMON_MODELS', 'IS_FAST_TEST'] - +__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd", "COMMON_MODELS", "IS_FAST_TEST"] diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index fce81ab52..a16b16ad6 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -102,4 +102,4 @@ class ModelZooRegistry(dict): return new_dict -model_zoo = ModelZooRegistry() \ No newline at end of file +model_zoo = ModelZooRegistry() diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index e27fdb4e2..0b178d58c 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -2,6 +2,7 @@ import torch from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel + from ..registry import ModelAttribute, model_zoo # ================================ diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 557666a80..d8a625b98 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -74,9 +74,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - booster.execute_pipeline( - _preprocess_data(data), model, _criterion, optimizer, return_loss=True - ) + booster.execute_pipeline(_preprocess_data(data), model, _criterion, optimizer, return_loss=True) else: output = model(**_preprocess_data(data)) loss = criterion(output) @@ -108,9 +106,7 @@ def exam_state_dict(shard: bool, model_name: str, 
size_per_shard: int, test_conf data_for_shard = data_gen_fn() data_for_origin = data_gen_fn() if booster.plugin.stage_manager is not None: - booster.execute_pipeline( - _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True - ) + booster.execute_pipeline(_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True) booster.execute_pipeline( _preprocess_data(data_for_origin), new_model, diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index dca562a3b..1ea70368e 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -113,6 +113,7 @@ def check_torch_fsdp_ckpt(): full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) import copy + sharded_osd = copy.deepcopy(full_osd) run_model() diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 9b650aa78..ded70fa43 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -1,16 +1,8 @@ -import math -import time - -import numpy as np import pytest import torch -import torch.nn as nn -import transformers from packaging import version try: - import triton - import triton.language as tl HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -22,6 +14,7 @@ try: from exllama_kernels import prepare_buffers, set_tuning_params from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False @@ -32,13 +25,14 @@ import warnings HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") max_inner_outer_dim = 1 max_input_len = 1 @@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False): max_input_len = 4096 # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
- gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) + gptq_temp_state_buffer = torch.zeros( + (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) @@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False): gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, - reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq", +) def test_gptq_linear(): - infeature = 1024 outfeature = 1024 group_size = 128 @@ -120,7 +115,7 @@ def test_gptq_linear(): max_input_len = 2048 buffers = { "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), - "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device), } prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) @@ -146,5 +141,4 @@ def test_gptq_linear(): if __name__ == "__main__": - test_gptq_linear() diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index aeca5f21d..d0c4cd0a7 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -24,4 +24,4 @@ def test_torchvision_models_lazy_init(subset, default_device): if __name__ == "__main__": - test_torchvision_models_lazy_init("transformers", "cpu") \ No newline at end of file + test_torchvision_models_lazy_init("transformers", "cpu") diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 3315b3256..603b7b6fa 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,5 +1,5 @@ -import torch import pytest +import torch from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.testing import clear_cache_before_run, parameterize @@ -17,6 +17,7 @@ def check_params_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" + # TODO Something wrong with ci when running this test. 
@pytest.mark.skip(reason="skip because of something wrong with CI") @clear_cache_before_run() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 7aa464055..f8820688e 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -103,9 +103,7 @@ def run_pp( torch_loss = criterion(torch_output) torch_loss.backward() - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) # check loss if stage_manager.is_last_stage(ignore_chunk=True): diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index e1a679890..590800780 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -99,9 +99,7 @@ def examine_pp(num_microbatch: int, batch_size: int): torch_output = torch_model(input_list[0]) torch_loss = criterion(torch_output) torch_loss.backward() - pp_ret = schedule.forward_backward_step( - sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True - ) + pp_ret = schedule.forward_backward_step(sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True) # check loss if stage_manager.is_last_stage():
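Note on the pipeline call sites collapsed onto one line throughout the hunks above (finetune.py, pretrain.py, opt_train_demo.py, vit_benchmark.py, and the hybrid-parallel checkpoint-IO test): they all share the same dispatch between pipeline and non-pipeline execution. The following is a minimal sketch of that shared pattern, assuming `booster`, `model`, `_criterion`, `optimizer`, and `dataloader_iter` are already set up as in those examples; the non-pipeline criterion signature follows vit_benchmark.py and is illustrative only.

    use_pipeline = booster.plugin.stage_manager is not None  # pipeline stages configured?
    if use_pipeline:
        # the pipeline schedule runs forward/backward internally and returns a dict holding the loss;
        # the loss entry is only meaningful on the last pipeline stage
        outputs = booster.execute_pipeline(dataloader_iter, model, _criterion, optimizer, return_loss=True)
        loss = outputs["loss"]
    else:
        batch = next(dataloader_iter)
        outputs = model(**batch)
        loss = _criterion(outputs, None)
        booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()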
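The PerformanceEvaluator hunk in examples/language/openmoe/benchmark/utils.py keeps the usual dense-transformer FLOP estimate: roughly 2 * model_numel FLOPs per token for a forward pass, a backward pass costing about twice the forward (hence the factor 3), plus one extra forward-equivalent recompute when gradient checkpointing is enabled. A standalone sanity check of that multiplier is sketched below; the function name and the example numbers are hypothetical, while the constants mirror the hunk's own approximation.

    def step_flops(batch_size: int, seq_len: int, model_numel: int, enable_grad_checkpoint: bool) -> int:
        # forward ~= 2 * numel FLOPs per token; backward ~= 2 * forward;
        # checkpointing adds one more forward-equivalent recompute pass
        passes = 3 + int(enable_grad_checkpoint)  # fwd (1) + bwd (2) [+ recompute (1)]
        return batch_size * seq_len * model_numel * 2 * passes

    # e.g. 7B params, batch 4, seq 4096, with checkpointing -> ~9.2e14 FLOPs per step
    print(step_flops(4, 4096, 7_000_000_000, True))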