mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5377 from hpcaitech/example/llama-npu
[llama] support npu for Colossal-LLaMA-2pull/5380/head
commit
4c03347fc7
|
@ -1,20 +1,16 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from datasets import Dataset as HFDataset
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset as HFDataset
|
||||
from datasets import dataset_dict, load_from_disk
|
||||
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||
PathType = Union[str, os.PathLike]
|
||||
|
@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
|
|||
tokenizer: PreTrainedTokenizer
|
||||
max_length: int = 4096
|
||||
ignore_index: int = -100
|
||||
padding: str = "max_length"
|
||||
|
||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
|
|||
batch_first=True,
|
||||
padding_value=self.ignore_index,
|
||||
) # (bsz, max_len)
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
if self.padding == "max_length":
|
||||
# pad to max
|
||||
to_pad = self.max_length - input_ids.size(1)
|
||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||
elif self.tokenizer.padding_side == "left":
|
||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
|
@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
|
|||
|
||||
def set_start_index(self, start_index: int) -> None:
|
||||
self.start_index = start_index
|
||||
|
||||
|
||||
def setup_distributed_dataloader(
|
||||
dataset: DatasetType,
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
**kwargs,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Setup dataloader for distributed training.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
process_group = process_group or _get_default_group()
|
||||
sampler = StatefulDistributedSampler(
|
||||
dataset=dataset,
|
||||
num_replicas=process_group.size(),
|
||||
rank=process_group.rank(),
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id: int) -> None:
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=drop_last,
|
||||
worker_init_fn=seed_worker,
|
||||
**_kwargs,
|
||||
)
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaForCausalLM,
|
||||
|
@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
|
|||
repeat_kv,
|
||||
)
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
if get_accelerator().name == "cuda":
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attetion mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
def _prepare_decoder_attention_mask(
|
||||
self: LlamaModel,
|
||||
attention_mask: torch.BoolTensor,
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Decoder attetion mask
|
||||
"""
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(input_shape[0], past_key_values_length),
|
||||
fill_value=True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
attention_mask,
|
||||
),
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
) # (bsz, past_key_values_length + q_len)
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # Faster
|
||||
return attention_mask
|
||||
|
||||
def attention_forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||
"""
|
||||
if output_attentions:
|
||||
logger.warning(
|
||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||
"return `None` instead."
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
q_slicing, kv_slicing = (
|
||||
dim // self.config.pretraining_tp
|
||||
for dim in (
|
||||
self.num_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
)
|
||||
) # `Tuple[int, int]`
|
||||
q_slices, k_slices, v_slices = (
|
||||
proj.weight.split(slicing, dim=0)
|
||||
for proj, slicing in (
|
||||
(self.q_proj, q_slicing),
|
||||
(self.k_proj, kv_slicing),
|
||||
(self.v_proj, kv_slicing),
|
||||
)
|
||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||
q, k, v = (
|
||||
torch.cat(
|
||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||
dim=-1,
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
for slices in (q_slices, k_slices, v_slices)
|
||||
)
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
else:
|
||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||
# (bsz, q_len, num_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
|
||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k, v = (
|
||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||
for states, num_heads in (
|
||||
(q, self.num_heads),
|
||||
(k, self.num_key_value_heads),
|
||||
(v, self.num_key_value_heads),
|
||||
)
|
||||
)
|
||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||
past_kv_len = 0
|
||||
if past_key_value is not None:
|
||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||
past_kv_len = past_key_value[0].shape[-2]
|
||||
kv_len += past_kv_len
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
|
||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
k = torch.cat([past_key_value[0], k], dim=2)
|
||||
v = torch.cat([past_key_value[1], v], dim=2)
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
|
||||
past_key_value = (k, v) if use_cache else None
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
key_padding_mask = attention_mask
|
||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
if past_kv_len > 0:
|
||||
q = torch.cat(
|
||||
tensors=(
|
||||
torch.full(
|
||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||
fill_value=0.0,
|
||||
dtype=q.dtype,
|
||||
device=q.device,
|
||||
),
|
||||
q,
|
||||
),
|
||||
q,
|
||||
),
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
dim=1,
|
||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
if key_padding_mask is None:
|
||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||
output = rearrange(
|
||||
output, pattern="... h d -> ... (h d)"
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
else:
|
||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||
attention_mask=key_padding_mask,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q=q,
|
||||
kv=kv,
|
||||
cu_seqlens_q=cu_q_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
max_seqlen_q=max_q_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = pad_input(
|
||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||
indices=indices,
|
||||
batch=bsz,
|
||||
seqlen=past_kv_len + q_len,
|
||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
if past_kv_len > 0:
|
||||
# Strip off the zero query outputs.
|
||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||
return output, None, past_key_value
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Formard function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
|
||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Formard function for RMS Norm
|
||||
"""
|
||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
|
||||
elif get_accelerator().name == "npu":
|
||||
import torch_npu
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.forward = MethodType(attention_forward, module)
|
||||
if isinstance(module, LlamaModel):
|
||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.forward = MethodType(rms_norm_forward, module)
|
||||
class NPULlamaAttention(LlamaAttention):
|
||||
use_flash: bool = True
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.use_flash:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
else:
|
||||
attn_output, *_ = torch_npu.npu_fusion_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self.num_heads,
|
||||
"BNSD",
|
||||
atten_mask=attention_mask.bool(),
|
||||
scale=self._softmax_scale,
|
||||
padding_mask=None,
|
||||
pre_tockens=65535,
|
||||
next_tockens=0,
|
||||
keep_prob=1.0,
|
||||
inner_precise=0,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum(
|
||||
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
class NPURMSNorm(LlamaRMSNorm):
|
||||
def forward(self, hidden_states):
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
|
||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
module.__class__ = NPULlamaAttention
|
||||
module.setup()
|
||||
if isinstance(module, LlamaRMSNorm):
|
||||
module.__class__ = NPURMSNorm
|
||||
|
|
|
@ -17,7 +17,7 @@ import torch
|
|||
|
||||
def unwrap(model):
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
return model.unwrap()
|
||||
else:
|
||||
return model
|
||||
|
||||
|
|
|
@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
|
|||
--warmup_steps 100 \
|
||||
--use_grad_checkpoint \
|
||||
--use_flash_attn \
|
||||
--pad_token "unk"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
|
|||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
|
@ -83,6 +85,7 @@ def main() -> None:
|
|||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||
|
@ -108,6 +111,12 @@ def main() -> None:
|
|||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
|
@ -116,6 +125,8 @@ def main() -> None:
|
|||
)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=1)
|
||||
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
|
@ -125,6 +136,7 @@ def main() -> None:
|
|||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
accelerator = get_accelerator()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
|
@ -182,7 +194,10 @@ def main() -> None:
|
|||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||
# ======================================================
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
if args.pad_token == "eos":
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.pad_token == "unk":
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
|
@ -193,38 +208,36 @@ def main() -> None:
|
|||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = setup_distributed_dataloader(
|
||||
data_collator = DataCollatorForSupervisedDataset(
|
||||
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
||||
)
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.micro_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
distributed_sampler_cls=StatefulDistributedSampler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
|
||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
||||
try:
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
current_device = get_accelerator().get_current_device()
|
||||
except:
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
current_device = get_current_device()
|
||||
|
||||
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device())
|
||||
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||
else nullcontext()
|
||||
)
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
||||
# Freeze part of parameters.
|
||||
if args.freeze_non_embeds_params:
|
||||
freeze_non_embeds_parameters(model=model)
|
||||
# this is essential, otherwise the grad checkpoint will not work.
|
||||
model.train()
|
||||
|
||||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
@ -246,12 +259,14 @@ def main() -> None:
|
|||
adamw_mode=True,
|
||||
)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * len(dataloader),
|
||||
warmup_steps=args.warmup_steps
|
||||
if args.warmup_steps is not None
|
||||
else int(args.num_epochs * len(dataloader) * 0.025),
|
||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
|
@ -267,11 +282,9 @@ def main() -> None:
|
|||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
if args.load_checkpoint is None:
|
||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
||||
booster.load_model(model, args.pretrained, strict=False)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
@ -298,85 +311,109 @@ def main() -> None:
|
|||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
num_steps_per_epoch = len(dataloader)
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||
# If resume training, set the sampler start index to the correct value
|
||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
with tqdm(
|
||||
iterable=enumerate(dataloader, start=start_step),
|
||||
pbar = tqdm(
|
||||
desc=f"Epoch {epoch}",
|
||||
disable=not coordinator.is_master(),
|
||||
total=num_steps_per_epoch,
|
||||
initial=start_step,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
initial=start_step // args.accumulation_steps,
|
||||
)
|
||||
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||
for step, batch in enumerate(dataloader, start=start_step):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
batch_output = model(**batch)
|
||||
|
||||
loss = batch_output.loss
|
||||
loss = batch_output.loss / args.accumulation_steps
|
||||
total_loss.add_(loss.data)
|
||||
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(tensor=loss)
|
||||
pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
|
||||
all_reduce_mean(tensor=total_loss)
|
||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||
if coordinator.is_master():
|
||||
global_step = epoch * num_steps_per_epoch + step
|
||||
writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
# Save modeling.
|
||||
total_loss.fill_(0.0)
|
||||
pbar.update()
|
||||
# Save modeling.
|
||||
|
||||
if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
# Delete CUDA cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
torch.cuda.empty_cache()
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
accelerator.empty_cache()
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
# Delete cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
accelerator.empty_cache()
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(start_index=0)
|
||||
start_step = 0
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
|||
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
|
||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
|
||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
|
||||
--pretrained $PRETRAINED_MODEL_PATH \
|
||||
--dataset ${dataset[@]} \
|
||||
--plugin "zero2" \
|
||||
|
@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
|
|||
--use_grad_checkpoint \
|
||||
--use_flash_attn \
|
||||
--use_neft \
|
||||
--pad_token "eos"
|
||||
|
|
|
@ -1,403 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama2.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
def format_numel_str(numel: int) -> str:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
if numel >= B:
|
||||
return f"{numel / B:.2f} B"
|
||||
elif numel >= M:
|
||||
return f"{numel / M:.2f} M"
|
||||
elif numel >= K:
|
||||
return f"{numel / K:.2f} K"
|
||||
else:
|
||||
return f"{numel}"
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Address of the pre-trained modeling",
|
||||
)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="gemini",
|
||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
||||
help="Choose which plugin to use",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
|
||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
choices=["fp16", "bf16"],
|
||||
help="Mixed precision",
|
||||
)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument(
|
||||
"--use_grad_checkpoint",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use gradient checkpointing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use flash-attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neft",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use NEFTune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze_non_embeds_params",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Freeze non embeddings parameters",
|
||||
)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
||||
# ==============================
|
||||
# Initialize Distributed Training
|
||||
# ==============================
|
||||
colossalai.launch_from_torch({})
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Initialize Tensorboard
|
||||
# ==============================
|
||||
if coordinator.is_master():
|
||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
|
||||
# ==============================
|
||||
# Initialize Booster
|
||||
# ==============================
|
||||
if args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2_cpu":
|
||||
plugin = LowLevelZeroPlugin(
|
||||
stage=2,
|
||||
precision=args.mixed_precision,
|
||||
initial_scale=2**16,
|
||||
cpu_offload=True,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=args.zero,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||
# ======================================================
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.add_bos_token = False
|
||||
tokenizer.add_eos_token = False
|
||||
|
||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||
|
||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
||||
dataloader = setup_distributed_dataloader(
|
||||
dataset=dataset,
|
||||
batch_size=args.micro_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
)
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
# Freeze part of parameters.
|
||||
if args.freeze_non_embeds_params:
|
||||
freeze_non_embeds_parameters(model=model)
|
||||
|
||||
if args.use_grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||
if args.use_flash_attn:
|
||||
replace_with_flash_attention(model=model)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
|
||||
model_numel = get_model_numel(model)
|
||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||
|
||||
optimizer = HybridAdam(
|
||||
model_params=filter(lambda p: p.requires_grad, model.parameters())
|
||||
if args.freeze_non_embeds_params
|
||||
else model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
if args.warmup_steps is None:
|
||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||
warmup_steps=args.warmup_steps,
|
||||
eta_min=0.1 * args.lr,
|
||||
)
|
||||
|
||||
# Flash attention will be disabled because it does NOT support fp32.
|
||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
||||
torch.set_default_dtype(default_dtype)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
dataloader=dataloader,
|
||||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
|
||||
if args.load_checkpoint is None:
|
||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
||||
booster.load_model(model, args.pretrained, strict=False)
|
||||
|
||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
start_epoch = 0
|
||||
start_step = 0
|
||||
sampler_start_idx = 0
|
||||
if args.load_checkpoint is not None:
|
||||
if "modeling" in args.load_checkpoint:
|
||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
|
||||
booster.load_model(model, args.load_checkpoint)
|
||||
else:
|
||||
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
|
||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
||||
load_dir=args.load_checkpoint,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
|
||||
)
|
||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||
# If resume training, set the sampler start index to the correct value
|
||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||
|
||||
for epoch in range(start_epoch, args.num_epochs):
|
||||
dataloader.sampler.set_epoch(epoch=epoch)
|
||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
||||
total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
|
||||
for step, batch in enumerate(dataloader):
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
||||
loss = batch_output.loss / args.accumulation_steps
|
||||
total_loss += loss.item()
|
||||
|
||||
booster.backward(loss=loss, optimizer=optimizer)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
all_reduce_mean(tensor=total_loss)
|
||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||
if coordinator.is_master():
|
||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||
writer.add_scalar(
|
||||
tag="Learning Rate",
|
||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||
global_step=global_step,
|
||||
)
|
||||
total_loss.fill_(0.0)
|
||||
pbar.update()
|
||||
# Save modeling.
|
||||
|
||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||
step + 1
|
||||
) == len(dataloader):
|
||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
save_checkpoint(
|
||||
save_dir=args.save_dir,
|
||||
booster=booster,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
epoch=epoch,
|
||||
step=step + 1,
|
||||
batch_size=args.micro_batch_size,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||
)
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Activate NEFTune.")
|
||||
model, handle = activate_neftune(model)
|
||||
|
||||
# Delete CUDA cache.
|
||||
# del batch, batch_labels, batch_output, loss
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||
dataloader.sampler.set_start_index(start_index=0)
|
||||
start_step = 0
|
||||
|
||||
if args.use_neft:
|
||||
coordinator.print_on_master("Deactivate NEFTune.")
|
||||
deactivate_neftune(model, handle)
|
||||
|
||||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -3,6 +3,8 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .huggingface import HuggingFaceModel
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
|
|||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
get_current_device()
|
||||
)
|
||||
|
||||
outputs = self.model(input_ids)[0]
|
||||
|
@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
|
|||
truncation=True,
|
||||
return_tensors="pt",
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
|
|
|
@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
|
|||
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
|
@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
|
|||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
self.model.to(get_current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
self.model.eval()
|
||||
|
@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
|
|||
"""
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||
torch.cuda.current_device()
|
||||
get_current_device()
|
||||
)
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
|
||||
|
||||
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
|
@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
|
|||
return_tensors="pt",
|
||||
return_token_type_ids=False,
|
||||
max_length=self.model_max_length - max_new_tokens,
|
||||
).to(torch.cuda.current_device())
|
||||
).to(get_current_device())
|
||||
|
||||
# Set output_scores=True to get prediction scores.
|
||||
outputs = self.model.generate(
|
||||
|
@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
|||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||
shard_former = ShardFormer(shard_config)
|
||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||
self.model.to(torch.cuda.current_device())
|
||||
self.model.to(get_current_device())
|
||||
|
||||
if peft_path is not None:
|
||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||
else:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||
if peft_path is not None:
|
||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.distributed as dist
|
|||
from colossal_eval import dataset, models, utils
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig
|
||||
|
@ -82,6 +83,7 @@ def rm_and_merge(
|
|||
|
||||
def main(args):
|
||||
colossalai.launch_from_torch(config={}, seed=42)
|
||||
accelerator = get_accelerator()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
@ -235,10 +237,10 @@ def main(args):
|
|||
),
|
||||
)
|
||||
|
||||
logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
|
||||
logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
|
||||
|
||||
del model_
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.empty_cache()
|
||||
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
|
|
|
@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
|
|||
self.world_size = dist.get_world_size()
|
||||
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
self,
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
distributed_sampler_cls=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
|
@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
|
|||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
||||
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||
sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
|
|
|
@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
|
|||
return ["cuda", "npu"]
|
||||
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
self,
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
distributed_sampler_cls=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
|
@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
|
|||
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
|
||||
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
|
||||
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
|
||||
sampler = DistributedSampler(
|
||||
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||
sampler = distributed_sampler_cls(
|
||||
dataset,
|
||||
num_replicas=zero_world_size * extra_dp_world_size,
|
||||
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
|
||||
|
|
|
@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return outputs
|
||||
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
self,
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
distributed_sampler_cls=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
|
@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
sampler = DistributedSampler(
|
||||
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||
sampler = distributed_sampler_cls(
|
||||
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
|
||||
)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
|
@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
tp_group=self.tp_group,
|
||||
use_zero=self.use_zero,
|
||||
inplace=False,
|
||||
device=torch.device("cuda"),
|
||||
device=get_current_device(),
|
||||
)
|
||||
|
||||
if self.pp_size == 1:
|
||||
|
@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero:
|
||||
v = v.cuda()
|
||||
v = v.to(get_current_device())
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
|
Loading…
Reference in New Issue