diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py index a2cfb2ef6..327651f4e 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -1,20 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import numpy as np import os -import random from dataclasses import dataclass -from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable +from typing import Dict, Iterator, List, Optional, Sequence, Union import torch -from datasets import dataset_dict, load_from_disk -from datasets import Dataset as HFDataset -from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler -from transformers.tokenization_utils import PreTrainedTokenizer import torch.nn.functional as F +from datasets import Dataset as HFDataset +from datasets import dataset_dict, load_from_disk +from torch.utils.data import ConcatDataset, Dataset, DistributedSampler +from transformers.tokenization_utils import PreTrainedTokenizer DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] @@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object): tokenizer: PreTrainedTokenizer max_length: int = 4096 ignore_index: int = -100 + padding: str = "max_length" def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ @@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object): batch_first=True, padding_value=self.ignore_index, ) # (bsz, max_len) - # pad to max - to_pad = self.max_length - input_ids.size(1) - input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) - labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + if self.padding == "max_length": + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) elif self.tokenizer.padding_side == "left": reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] reversed_input_ids = torch.nn.utils.rnn.pad_sequence( @@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler): def set_start_index(self, start_index: int) -> None: self.start_index = start_index - - -def setup_distributed_dataloader( - dataset: DatasetType, - batch_size: int = 1, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, - process_group: Optional[ProcessGroup] = None, - **kwargs, -) -> DataLoader: - """ - Setup dataloader for distributed training. 
- """ - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset=dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle, - seed=seed, - drop_last=drop_last, - ) - - # Deterministic dataloader - def seed_worker(worker_id: int) -> None: - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - worker_init_fn=seed_worker, - **_kwargs, - ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 1926ec78a..6c048c3b1 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import math from types import MethodType from typing import Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func -from flash_attn.ops.rms_norm import rms_norm +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import ( repeat_kv, ) +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger logger = get_dist_logger() +if get_accelerator().name == "cuda": + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.ops.rms_norm import rms_norm -def _prepare_decoder_attention_mask( - self: LlamaModel, - attention_mask: torch.BoolTensor, - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, -) -> Optional[torch.Tensor]: - """ - Decoder attetion mask - """ - if past_key_values_length > 0 and attention_mask is not None: - attention_mask = torch.cat( - tensors=( - torch.full( - size=(input_shape[0], past_key_values_length), - fill_value=True, - dtype=attention_mask.dtype, - device=attention_mask.device, + def _prepare_decoder_attention_mask( + self: LlamaModel, + attention_mask: torch.BoolTensor, + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ) -> Optional[torch.Tensor]: + """ + Decoder attetion mask + """ + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + tensors=( + torch.full( + size=(input_shape[0], past_key_values_length), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, ), - attention_mask, - ), - dim=-1, - ) # (bsz, past_key_values_length + q_len) - if attention_mask is not None and torch.all(attention_mask): - return None # Faster - return attention_mask - - -def attention_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: 
Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. - """ - if output_attentions: - logger.warning( - "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " - "return `None` instead." - ) - - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - q_slicing, kv_slicing = ( - dim // self.config.pretraining_tp - for dim in ( - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, - ) - ) # `Tuple[int, int]` - q_slices, k_slices, v_slices = ( - proj.weight.split(slicing, dim=0) - for proj, slicing in ( - (self.q_proj, q_slicing), - (self.k_proj, kv_slicing), - (self.v_proj, kv_slicing), - ) - ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] - q, k, v = ( - torch.cat( - [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], dim=-1, + ) # (bsz, past_key_values_length + q_len) + if attention_mask is not None and torch.all(attention_mask): + return None # Faster + return attention_mask + + def attention_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. + """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." 
+ ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) + ) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) + q, k, v = ( + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), ) - for slices in (q_slices, k_slices, v_slices) ) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - else: - q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. + past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len - # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) - q, k, v = ( - states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) - for states, num_heads in ( - (q, self.num_heads), - (k, self.num_key_value_heads), - (v, self.num_key_value_heads), - ) - ) - kv_len = k.shape[-2] # initially, `kv_len` == `q_len` - past_kv_len = 0 - if past_key_value is not None: - # if `past_key_value` is not None, `kv_len` > `q_len`. 
- past_kv_len = past_key_value[0].shape[-2] - kv_len += past_kv_len + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) - # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) - cos, sin = self.rotary_emb(v, seq_len=kv_len) - # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) - q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) - if past_key_value is not None: - # reuse k, v, self_attention - k = torch.cat([past_key_value[0], k], dim=2) - v = torch.cat([past_key_value[1], v], dim=2) + past_key_value = (k, v) if use_cache else None - past_key_value = (k, v) if use_cache else None + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - # repeat k/v heads if n_kv_heads < n_heads - k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) - key_padding_mask = attention_mask - # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) - q, k, v = (states.transpose(1, 2) for states in (q, k, v)) - - if past_kv_len > 0: - q = torch.cat( - tensors=( - torch.full( - size=(bsz, past_kv_len, self.num_heads, self.head_dim), - fill_value=0.0, - dtype=q.dtype, - device=q.device, + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, ), - q, - ), - dim=1, - ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) - if key_padding_mask is None: - # (bsz, past_kv_len + q_len, num_heads, head_dim) - output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) - output = rearrange(output, pattern="... h d -> ... 
(h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim) - else: - q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) - kv, _, cu_kv_lens, max_kv_len = unpad_input( - hidden_states=torch.stack(tensors=(k, v), dim=2), - attention_mask=key_padding_mask, - ) - output_unpad = flash_attn_varlen_kvpacked_func( - q=q, - kv=kv, - cu_seqlens_q=cu_q_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_q_len, - max_seqlen_k=max_kv_len, - dropout_p=0.0, - softmax_scale=None, - causal=True, - ) - output = pad_input( - hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), - indices=indices, - batch=bsz, - seqlen=past_kv_len + q_len, - ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange( + output, pattern="... h d -> ... (h d)" + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) - if past_kv_len > 0: - # Strip off the zero query outputs. - output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) - output = self.o_proj(output) # (bsz, q_len, hidden_size) - return output, None, past_key_value + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] 
# (bsz, q_len, num_heads * head_dim) + output = self.o_proj(output) # (bsz, q_len, hidden_size) + return output, None, past_key_value + def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Formard function for RMS Norm + """ + return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) -def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Formard function for RMS Norm - """ - return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(attention_forward, module) + if isinstance(module, LlamaModel): + module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) + if isinstance(module, LlamaRMSNorm): + module.forward = MethodType(rms_norm_forward, module) +elif get_accelerator().name == "npu": + import torch_npu -def replace_with_flash_attention(model: LlamaForCausalLM) -> None: - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - module.forward = MethodType(attention_forward, module) - if isinstance(module, LlamaModel): - module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) - if isinstance(module, LlamaRMSNorm): - module.forward = MethodType(rms_norm_forward, module) + class NPULlamaAttention(LlamaAttention): + use_flash: bool = True + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.setup() + + def setup(self): + self._softmax_scale = 1 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = 
self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if not self.use_flash: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + else: + attn_output, *_ = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + self.num_heads, + "BNSD", + atten_mask=attention_mask.bool(), + scale=self._softmax_scale, + padding_mask=None, + pre_tockens=65535, + next_tockens=0, + keep_prob=1.0, + inner_precise=0, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum( + [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class NPURMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.__class__ = NPULlamaAttention + module.setup() + if isinstance(module, LlamaRMSNorm): + module.__class__ = NPURMSNorm diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 9f6c9c1cc..21d769f3c 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -17,7 +17,7 @@ import torch def unwrap(model): if hasattr(model, "module"): - return unwrap_model(model.module) + return model.unwrap() else: return model diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh index 
276d9ce99..6a1c887bf 100644 --- a/applications/Colossal-LLaMA-2/train.example.sh +++ b/applications/Colossal-LLaMA-2/train.example.sh @@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train. --warmup_steps 100 \ --use_grad_checkpoint \ --use_flash_attn \ + --pad_token "unk" diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 92863e8e4..20ec2a7c8 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team """ import argparse @@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import ( DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset, - setup_distributed_dataloader, ) from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention from colossal_llama2.utils.froze import freeze_non_embeds_parameters +from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device def get_model_numel(model: torch.nn.Module) -> int: @@ -83,6 +85,7 @@ def main() -> None: parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") parser.add_argument("--config_file", type=str, default="config_file", help="Config file") parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=4096, help="Model max length") @@ -108,6 +111,12 @@ def main() -> None: default=False, help="Use flash-attention", ) + parser.add_argument( + "--use_neft", + action="store_true", + default=False, + help="Use NEFTune", + ) parser.add_argument( "--freeze_non_embeds_params", action="store_true", @@ -116,6 +125,8 @@ def main() -> None: ) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--zero", type=int, default=1) + parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") + parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") args = parser.parse_args() with open(args.config_file, "w") as f: @@ -125,6 +136,7 @@ def main() -> None: # Initialize Distributed Training # ============================== colossalai.launch_from_torch({}) + accelerator = get_accelerator() coordinator = DistCoordinator() # 
============================== @@ -182,7 +194,10 @@ def main() -> None: # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.unk_token + if args.pad_token == "eos": + tokenizer.pad_token = tokenizer.eos_token + elif args.pad_token == "unk": + tokenizer.pad_token = tokenizer.unk_token tokenizer.add_bos_token = False tokenizer.add_eos_token = False @@ -193,38 +208,36 @@ def main() -> None: coordinator.print_on_master(f"Load dataset: {args.dataset}") dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( + data_collator = DataCollatorForSupervisedDataset( + tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode + ) + dataloader = plugin.prepare_dataloader( dataset=dataset, batch_size=args.micro_batch_size, shuffle=True, drop_last=True, collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, ) coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - - # colossalai has changed api for get_current_device in 0.3.4 version or newer - try: - from colossalai.accelerator import get_accelerator - - current_device = get_accelerator().get_current_device() - except: - from colossalai.utils import get_current_device - - current_device = get_current_device() - - init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) + model = LlamaForCausalLM.from_pretrained(args.pretrained) # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) + # this is essential, otherwise the grad checkpoint will not work. 
+ model.train() if args.use_grad_checkpoint: model.gradient_checkpointing_enable() @@ -246,12 +259,14 @@ def main() -> None: adamw_mode=True, ) + if args.warmup_steps is None: + args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + lr_scheduler = CosineAnnealingWarmupLR( optimizer=optimizer, - total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps - if args.warmup_steps is not None - else int(args.num_epochs * len(dataloader) * 0.025), + total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), + warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr, ) @@ -267,11 +282,9 @@ def main() -> None: torch.set_default_dtype(torch.float) - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) @@ -298,85 +311,109 @@ def main() -> None: coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) - num_steps_per_epoch = len(dataloader) + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + num_steps_per_epoch = len(dataloader) // args.accumulation_steps # If resume training, set the sampler start index to the correct value assert isinstance(dataloader.sampler, StatefulDistributedSampler) dataloader.sampler.set_start_index(start_index=sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch=epoch) - with tqdm( - iterable=enumerate(dataloader, start=start_step), + pbar = tqdm( desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step, batch in pbar: - batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(dataloader, start=start_step): + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - batch_output = model(**batch) + batch_output = model(**batch) - loss = batch_output.loss + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) - booster.backward(loss=loss, optimizer=optimizer) + booster.backward(loss=loss, optimizer=optimizer) + if (step + 1) % args.accumulation_steps == 0: optimizer.step() lr_scheduler.step() optimizer.zero_grad() - 
all_reduce_mean(tensor=loss) - pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) + all_reduce_mean(tensor=total_loss) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) if coordinator.is_master(): - global_step = epoch * num_steps_per_epoch + step - writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step) + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) writer.add_scalar( tag="Learning Rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=global_step, ) - # Save modeling. + total_loss.fill_(0.0) + pbar.update() + # Save modeling. - if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) + if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( + step + 1 + ) == len(dataloader): + coordinator.print_on_master("\nStart saving model checkpoint with running states") - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.micro_batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + # Delete cache. + # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(start_index=0) start_step = 0 + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune.") + deactivate_neftune(model, handle) + # Final save. 
coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh index dcb11515d..d87f9ef82 100755 --- a/applications/Colossal-LLaMA-2/train_sft.example.sh +++ b/applications/Colossal-LLaMA-2/train_sft.example.sh @@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}" CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" -colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \ +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \ --pretrained $PRETRAINED_MODEL_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ @@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_ --use_grad_checkpoint \ --use_flash_attn \ --use_neft \ + --pad_token "eos" diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py deleted file mode 100644 index fd9e1cd3e..000000000 --- a/applications/Colossal-LLaMA-2/train_sft.py +++ /dev/null @@ -1,403 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team -""" - -import argparse -import json -import os -import resource -from contextlib import nullcontext - -import torch -import torch.distributed as dist -from colossal_llama2.dataset.loader import ( - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, - load_tokenized_dataset, - setup_distributed_dataloader, -) -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.froze import freeze_non_embeds_parameters -from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - - -def get_model_numel(model: torch.nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor.div_(dist.get_world_size()) - return tensor - - -def main() -> None: - # ============================== - # 
Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument( - "--pretrained", - type=str, - default=None, - help="Address of the pre-trained modeling", - ) - parser.add_argument("--dataset", nargs="+", default=[]) - parser.add_argument( - "--plugin", - type=str, - default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], - help="Choose which plugin to use", - ) - parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") - parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") - parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") - parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps") - parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=4096, help="Model max length") - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["fp16", "bf16"], - help="Mixed precision", - ) - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") - parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") - parser.add_argument( - "--use_grad_checkpoint", - action="store_true", - default=False, - help="Use gradient checkpointing", - ) - parser.add_argument( - "--use_flash_attn", - action="store_true", - default=False, - help="Use flash-attention", - ) - parser.add_argument( - "--use_neft", - action="store_true", - default=False, - help="Use NEFTune", - ) - parser.add_argument( - "--freeze_non_embeds_params", - action="store_true", - default=False, - help="Freeze non embeddings parameters", - ) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--zero", type=int, default=1) - args = parser.parse_args() - - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Tensorboard - # ============================== - if coordinator.is_master(): - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin( - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, - placement_policy="auto", - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, - 
precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip, - ) - elif args.plugin == "3d": - plugin = HybridParallelPlugin( - tp_size=args.tp, - pp_size=1, - zero_stage=args.zero, - max_norm=args.grad_clip, - precision=args.mixed_precision, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - # ====================================================== - # Initialize Tokenizer, Dataset, Collator and Dataloader - # ====================================================== - tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.add_bos_token = False - tokenizer.add_eos_token = False - - coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") - coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") - coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") - - coordinator.print_on_master(f"Load dataset: {args.dataset}") - - dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( - dataset=dataset, - batch_size=args.micro_batch_size, - shuffle=True, - drop_last=True, - collate_fn=data_collator, - ) - coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - - # ====================================================== - # Initialize Model, Objective, Optimizer and LR Scheduler - # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) - with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) - # Freeze part of parameters. - if args.freeze_non_embeds_params: - freeze_non_embeds_parameters(model=model) - - if args.use_grad_checkpoint: - model.gradient_checkpointing_enable() - coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - if args.use_flash_attn: - replace_with_flash_attention(model=model) - coordinator.print_on_master(msg="Flash-attention enabled successfully") - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam( - model_params=filter(lambda p: p.requires_grad, model.parameters()) - if args.freeze_non_embeds_params - else model.parameters(), - lr=args.lr, - betas=(0.9, 0.95), - weight_decay=args.weight_decay, - adamw_mode=True, - ) - - if args.warmup_steps is None: - args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) - coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") - - lr_scheduler = CosineAnnealingWarmupLR( - optimizer=optimizer, - total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), - warmup_steps=args.warmup_steps, - eta_min=0.1 * args.lr, - ) - - # Flash attention will be disabled because it does NOT support fp32. 
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - dataloader=dataloader, - ) - - torch.set_default_dtype(torch.float) - - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load_checkpoint is not None: - if "modeling" in args.load_checkpoint: - coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") - booster.load_model(model, args.load_checkpoint) - else: - coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") - start_epoch, start_step, sampler_start_idx = load_checkpoint( - load_dir=args.load_checkpoint, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - coordinator.print_on_master( - f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" - ) - coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") - - coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - num_steps_per_epoch = len(dataloader) // args.accumulation_steps - # If resume training, set the sampler start index to the correct value - assert isinstance(dataloader.sampler, StatefulDistributedSampler) - dataloader.sampler.set_start_index(start_index=sampler_start_idx) - - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch=epoch) - pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch) - total_loss = torch.tensor(0.0).to(torch.cuda.current_device()) - for step, batch in enumerate(dataloader): - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - - batch_output = model(**batch) - - loss = batch_output.loss / args.accumulation_steps - total_loss += loss.item() - - booster.backward(loss=loss, optimizer=optimizer) - - if (step + 1) % args.accumulation_steps == 0: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - all_reduce_mean(tensor=total_loss) - pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) - if coordinator.is_master(): - global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps - writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) - writer.add_scalar( - tag="Learning Rate", - scalar_value=lr_scheduler.get_last_lr()[0], - global_step=global_step, - ) - total_loss.fill_(0.0) - pbar.update() - # Save modeling. 
- - if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( - step + 1 - ) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune before saving model.") - deactivate_neftune(model, handle) - - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() - - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(start_index=0) - start_step = 0 - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune.") - deactivate_neftune(model, handle) - - # Final save. - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py index f293c4f69..9c70c0d2a 100644 --- a/applications/ColossalEval/colossal_eval/models/chatglm.py +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -3,6 +3,8 @@ from typing import List import torch +from colossalai.utils import get_current_device + from .huggingface import HuggingFaceModel IGNORE_INDEX = -100 @@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel): """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) outputs = self.model(input_ids)[0] @@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel): truncation=True, return_tensors="pt", max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. 
outputs = self.model.generate( diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 741c884f0..fff697e21 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni from colossalai.logging import DistributedLogger from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils import get_current_device from .base import BaseModel @@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel): self.model = AutoModel.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) self.model.eval() @@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel): """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) - attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device()) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device()) outputs = self.model(input_ids, attention_mask=attention_mask)[0] @@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel): return_tensors="pt", return_token_type_ids=False, max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. 
outputs = self.model.generate( @@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel): self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 5b09f9de8..a340f3bfd 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -8,6 +8,7 @@ import torch.distributed as dist from colossal_eval import dataset, models, utils import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig @@ -82,6 +83,7 @@ def rm_and_merge( def main(args): colossalai.launch_from_torch(config={}, seed=42) + accelerator = get_accelerator() world_size = dist.get_world_size() rank = dist.get_rank() @@ -235,10 +237,10 @@ def main(args): ), ) - logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ - torch.cuda.empty_cache() + accelerator.empty_cache() dist.barrier() if rank == 0: diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py index d2dd00453..27285f95c 100644 --- a/colossalai/booster/plugin/dp_plugin_base.py +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -21,7 +21,16 @@ class DPPluginBase(Plugin): self.world_size = dist.get_world_size() def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -45,7 +54,8 @@ class DPPluginBase(Plugin): :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) # Deterministic dataloader def seed_worker(worker_id): diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index d14109dd4..95b96bbfd 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase): return ["cuda", "npu"] def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase): extra_dp_world_size = self.pg_mesh.size(DP_AXIS) zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) - sampler = DistributedSampler( + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 943e137e6..da67e6b41 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase): return outputs def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase): :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
""" _kwargs = kwargs.copy() - sampler = DistributedSampler( + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle ) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5f832f13c..36df30335 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): tp_group=self.tp_group, use_zero=self.use_zero, inplace=False, - device=torch.device("cuda"), + device=get_current_device(), ) if self.pp_size == 1: @@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. if use_zero: - v = v.cuda() + v = v.to(get_current_device()) gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] dist.all_gather(gather_tensor, v, group=dp_group) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)