mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5377 from hpcaitech/example/llama-npu
[llama] support npu for Colossal-LLaMA-2pull/5380/head
commit
4c03347fc7
|
@ -1,20 +1,16 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
|
from typing import Dict, Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import dataset_dict, load_from_disk
|
|
||||||
from datasets import Dataset as HFDataset
|
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
from torch.distributed.distributed_c10d import _get_default_group
|
|
||||||
from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
|
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset as HFDataset
|
||||||
|
from datasets import dataset_dict, load_from_disk
|
||||||
|
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
|
||||||
PathType = Union[str, os.PathLike]
|
PathType = Union[str, os.PathLike]
|
||||||
|
@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
|
||||||
tokenizer: PreTrainedTokenizer
|
tokenizer: PreTrainedTokenizer
|
||||||
max_length: int = 4096
|
max_length: int = 4096
|
||||||
ignore_index: int = -100
|
ignore_index: int = -100
|
||||||
|
padding: str = "max_length"
|
||||||
|
|
||||||
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
@ -106,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
padding_value=self.ignore_index,
|
padding_value=self.ignore_index,
|
||||||
) # (bsz, max_len)
|
) # (bsz, max_len)
|
||||||
# pad to max
|
if self.padding == "max_length":
|
||||||
to_pad = self.max_length - input_ids.size(1)
|
# pad to max
|
||||||
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
to_pad = self.max_length - input_ids.size(1)
|
||||||
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
|
||||||
|
labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
|
||||||
elif self.tokenizer.padding_side == "left":
|
elif self.tokenizer.padding_side == "left":
|
||||||
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
|
||||||
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
@ -171,49 +169,3 @@ class StatefulDistributedSampler(DistributedSampler):
|
||||||
|
|
||||||
def set_start_index(self, start_index: int) -> None:
|
def set_start_index(self, start_index: int) -> None:
|
||||||
self.start_index = start_index
|
self.start_index = start_index
|
||||||
|
|
||||||
|
|
||||||
def setup_distributed_dataloader(
|
|
||||||
dataset: DatasetType,
|
|
||||||
batch_size: int = 1,
|
|
||||||
shuffle: bool = False,
|
|
||||||
seed: int = 1024,
|
|
||||||
drop_last: bool = False,
|
|
||||||
pin_memory: bool = False,
|
|
||||||
num_workers: int = 0,
|
|
||||||
collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
|
|
||||||
process_group: Optional[ProcessGroup] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> DataLoader:
|
|
||||||
"""
|
|
||||||
Setup dataloader for distributed training.
|
|
||||||
"""
|
|
||||||
_kwargs = kwargs.copy()
|
|
||||||
process_group = process_group or _get_default_group()
|
|
||||||
sampler = StatefulDistributedSampler(
|
|
||||||
dataset=dataset,
|
|
||||||
num_replicas=process_group.size(),
|
|
||||||
rank=process_group.rank(),
|
|
||||||
shuffle=shuffle,
|
|
||||||
seed=seed,
|
|
||||||
drop_last=drop_last,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deterministic dataloader
|
|
||||||
def seed_worker(worker_id: int) -> None:
|
|
||||||
worker_seed = seed
|
|
||||||
np.random.seed(worker_seed)
|
|
||||||
torch.manual_seed(worker_seed)
|
|
||||||
random.seed(worker_seed)
|
|
||||||
|
|
||||||
return DataLoader(
|
|
||||||
dataset=dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
sampler=sampler,
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
drop_last=drop_last,
|
|
||||||
worker_init_fn=seed_worker,
|
|
||||||
**_kwargs,
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import math
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
|
||||||
from flash_attn.ops.rms_norm import rms_norm
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
|
@ -19,194 +19,334 @@ from transformers.models.llama.modeling_llama import (
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
|
|
||||||
|
if get_accelerator().name == "cuda":
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
|
||||||
|
from flash_attn.ops.rms_norm import rms_norm
|
||||||
|
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
self: LlamaModel,
|
self: LlamaModel,
|
||||||
attention_mask: torch.BoolTensor,
|
attention_mask: torch.BoolTensor,
|
||||||
input_shape: torch.Size,
|
input_shape: torch.Size,
|
||||||
inputs_embeds: torch.Tensor,
|
inputs_embeds: torch.Tensor,
|
||||||
past_key_values_length: int,
|
past_key_values_length: int,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Decoder attetion mask
|
Decoder attetion mask
|
||||||
"""
|
"""
|
||||||
if past_key_values_length > 0 and attention_mask is not None:
|
if past_key_values_length > 0 and attention_mask is not None:
|
||||||
attention_mask = torch.cat(
|
attention_mask = torch.cat(
|
||||||
tensors=(
|
tensors=(
|
||||||
torch.full(
|
torch.full(
|
||||||
size=(input_shape[0], past_key_values_length),
|
size=(input_shape[0], past_key_values_length),
|
||||||
fill_value=True,
|
fill_value=True,
|
||||||
dtype=attention_mask.dtype,
|
dtype=attention_mask.dtype,
|
||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
|
),
|
||||||
|
attention_mask,
|
||||||
),
|
),
|
||||||
attention_mask,
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
) # (bsz, past_key_values_length + q_len)
|
|
||||||
if attention_mask is not None and torch.all(attention_mask):
|
|
||||||
return None # Faster
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
def attention_forward(
|
|
||||||
self: LlamaAttention,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""
|
|
||||||
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
|
||||||
"""
|
|
||||||
if output_attentions:
|
|
||||||
logger.warning(
|
|
||||||
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
|
||||||
"return `None` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
if self.config.pretraining_tp > 1:
|
|
||||||
q_slicing, kv_slicing = (
|
|
||||||
dim // self.config.pretraining_tp
|
|
||||||
for dim in (
|
|
||||||
self.num_heads * self.head_dim,
|
|
||||||
self.num_key_value_heads * self.head_dim,
|
|
||||||
)
|
|
||||||
) # `Tuple[int, int]`
|
|
||||||
q_slices, k_slices, v_slices = (
|
|
||||||
proj.weight.split(slicing, dim=0)
|
|
||||||
for proj, slicing in (
|
|
||||||
(self.q_proj, q_slicing),
|
|
||||||
(self.k_proj, kv_slicing),
|
|
||||||
(self.v_proj, kv_slicing),
|
|
||||||
)
|
|
||||||
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
|
||||||
q, k, v = (
|
|
||||||
torch.cat(
|
|
||||||
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
|
||||||
dim=-1,
|
dim=-1,
|
||||||
|
) # (bsz, past_key_values_length + q_len)
|
||||||
|
if attention_mask is not None and torch.all(attention_mask):
|
||||||
|
return None # Faster
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
def attention_forward(
|
||||||
|
self: LlamaAttention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""
|
||||||
|
Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
|
||||||
|
"""
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning(
|
||||||
|
"Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
|
||||||
|
"return `None` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
q_slicing, kv_slicing = (
|
||||||
|
dim // self.config.pretraining_tp
|
||||||
|
for dim in (
|
||||||
|
self.num_heads * self.head_dim,
|
||||||
|
self.num_key_value_heads * self.head_dim,
|
||||||
|
)
|
||||||
|
) # `Tuple[int, int]`
|
||||||
|
q_slices, k_slices, v_slices = (
|
||||||
|
proj.weight.split(slicing, dim=0)
|
||||||
|
for proj, slicing in (
|
||||||
|
(self.q_proj, q_slicing),
|
||||||
|
(self.k_proj, kv_slicing),
|
||||||
|
(self.v_proj, kv_slicing),
|
||||||
|
)
|
||||||
|
) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
|
||||||
|
q, k, v = (
|
||||||
|
torch.cat(
|
||||||
|
[F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
for slices in (q_slices, k_slices, v_slices)
|
||||||
|
)
|
||||||
|
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||||
|
# (bsz, q_len, num_heads * head_dim),
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||||
|
else:
|
||||||
|
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
||||||
|
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
||||||
|
# (bsz, q_len, num_heads * head_dim),
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim),
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim)
|
||||||
|
|
||||||
|
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
||||||
|
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
||||||
|
q, k, v = (
|
||||||
|
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
for states, num_heads in (
|
||||||
|
(q, self.num_heads),
|
||||||
|
(k, self.num_key_value_heads),
|
||||||
|
(v, self.num_key_value_heads),
|
||||||
)
|
)
|
||||||
for slices in (q_slices, k_slices, v_slices)
|
|
||||||
)
|
)
|
||||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
||||||
# (bsz, q_len, num_heads * head_dim),
|
past_kv_len = 0
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
if past_key_value is not None:
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
||||||
else:
|
past_kv_len = past_key_value[0].shape[-2]
|
||||||
q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
|
kv_len += past_kv_len
|
||||||
# `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
|
|
||||||
# (bsz, q_len, num_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim),
|
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim)
|
|
||||||
|
|
||||||
# (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
|
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
|
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
||||||
# (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
|
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
||||||
q, k, v = (
|
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
||||||
states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
|
if past_key_value is not None:
|
||||||
for states, num_heads in (
|
# reuse k, v, self_attention
|
||||||
(q, self.num_heads),
|
k = torch.cat([past_key_value[0], k], dim=2)
|
||||||
(k, self.num_key_value_heads),
|
v = torch.cat([past_key_value[1], v], dim=2)
|
||||||
(v, self.num_key_value_heads),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
|
|
||||||
past_kv_len = 0
|
|
||||||
if past_key_value is not None:
|
|
||||||
# if `past_key_value` is not None, `kv_len` > `q_len`.
|
|
||||||
past_kv_len = past_key_value[0].shape[-2]
|
|
||||||
kv_len += past_kv_len
|
|
||||||
|
|
||||||
# two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
|
past_key_value = (k, v) if use_cache else None
|
||||||
cos, sin = self.rotary_emb(v, seq_len=kv_len)
|
|
||||||
# (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
|
|
||||||
q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
k = torch.cat([past_key_value[0], k], dim=2)
|
|
||||||
v = torch.cat([past_key_value[1], v], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (k, v) if use_cache else None
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
||||||
|
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||||
|
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
||||||
|
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
key_padding_mask = attention_mask
|
||||||
k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
|
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
||||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
||||||
v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
|
|
||||||
# (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
|
|
||||||
|
|
||||||
key_padding_mask = attention_mask
|
if past_kv_len > 0:
|
||||||
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
|
q = torch.cat(
|
||||||
q, k, v = (states.transpose(1, 2) for states in (q, k, v))
|
tensors=(
|
||||||
|
torch.full(
|
||||||
if past_kv_len > 0:
|
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
||||||
q = torch.cat(
|
fill_value=0.0,
|
||||||
tensors=(
|
dtype=q.dtype,
|
||||||
torch.full(
|
device=q.device,
|
||||||
size=(bsz, past_kv_len, self.num_heads, self.head_dim),
|
),
|
||||||
fill_value=0.0,
|
q,
|
||||||
dtype=q.dtype,
|
|
||||||
device=q.device,
|
|
||||||
),
|
),
|
||||||
q,
|
dim=1,
|
||||||
),
|
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||||
dim=1,
|
|
||||||
) # (bsz, past_kv_len + q_len, num_heads, head_dim)
|
|
||||||
|
|
||||||
if key_padding_mask is None:
|
if key_padding_mask is None:
|
||||||
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
# (bsz, past_kv_len + q_len, num_heads, head_dim)
|
||||||
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
|
||||||
output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
output = rearrange(
|
||||||
else:
|
output, pattern="... h d -> ... (h d)"
|
||||||
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||||
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
else:
|
||||||
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
|
||||||
attention_mask=key_padding_mask,
|
kv, _, cu_kv_lens, max_kv_len = unpad_input(
|
||||||
)
|
hidden_states=torch.stack(tensors=(k, v), dim=2),
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
attention_mask=key_padding_mask,
|
||||||
q=q,
|
)
|
||||||
kv=kv,
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
cu_seqlens_q=cu_q_lens,
|
q=q,
|
||||||
cu_seqlens_k=cu_kv_lens,
|
kv=kv,
|
||||||
max_seqlen_q=max_q_len,
|
cu_seqlens_q=cu_q_lens,
|
||||||
max_seqlen_k=max_kv_len,
|
cu_seqlens_k=cu_kv_lens,
|
||||||
dropout_p=0.0,
|
max_seqlen_q=max_q_len,
|
||||||
softmax_scale=None,
|
max_seqlen_k=max_kv_len,
|
||||||
causal=True,
|
dropout_p=0.0,
|
||||||
)
|
softmax_scale=None,
|
||||||
output = pad_input(
|
causal=True,
|
||||||
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
)
|
||||||
indices=indices,
|
output = pad_input(
|
||||||
batch=bsz,
|
hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
|
||||||
seqlen=past_kv_len + q_len,
|
indices=indices,
|
||||||
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
batch=bsz,
|
||||||
|
seqlen=past_kv_len + q_len,
|
||||||
|
) # (bsz, past_kv_len + q_len, num_heads * head_dim)
|
||||||
|
|
||||||
if past_kv_len > 0:
|
if past_kv_len > 0:
|
||||||
# Strip off the zero query outputs.
|
# Strip off the zero query outputs.
|
||||||
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
|
||||||
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
output = self.o_proj(output) # (bsz, q_len, hidden_size)
|
||||||
return output, None, past_key_value
|
return output, None, past_key_value
|
||||||
|
|
||||||
|
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Formard function for RMS Norm
|
||||||
|
"""
|
||||||
|
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
||||||
|
|
||||||
def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
|
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||||
"""
|
for name, module in model.named_modules():
|
||||||
Formard function for RMS Norm
|
if isinstance(module, LlamaAttention):
|
||||||
"""
|
module.forward = MethodType(attention_forward, module)
|
||||||
return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
|
if isinstance(module, LlamaModel):
|
||||||
|
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
||||||
|
if isinstance(module, LlamaRMSNorm):
|
||||||
|
module.forward = MethodType(rms_norm_forward, module)
|
||||||
|
|
||||||
|
elif get_accelerator().name == "npu":
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
class NPULlamaAttention(LlamaAttention):
|
||||||
for name, module in model.named_modules():
|
use_flash: bool = True
|
||||||
if isinstance(module, LlamaAttention):
|
|
||||||
module.forward = MethodType(attention_forward, module)
|
def __init__(self, config: LlamaConfig):
|
||||||
if isinstance(module, LlamaModel):
|
super().__init__(config)
|
||||||
module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
|
self.setup()
|
||||||
if isinstance(module, LlamaRMSNorm):
|
|
||||||
module.forward = MethodType(rms_norm_forward, module)
|
def setup(self):
|
||||||
|
self._softmax_scale = 1 / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if not self.use_flash:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
else:
|
||||||
|
attn_output, *_ = torch_npu.npu_fusion_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
self.num_heads,
|
||||||
|
"BNSD",
|
||||||
|
atten_mask=attention_mask.bool(),
|
||||||
|
scale=self._softmax_scale,
|
||||||
|
padding_mask=None,
|
||||||
|
pre_tockens=65535,
|
||||||
|
next_tockens=0,
|
||||||
|
keep_prob=1.0,
|
||||||
|
inner_precise=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||||
|
attn_output = sum(
|
||||||
|
[F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
class NPURMSNorm(LlamaRMSNorm):
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||||
|
|
||||||
|
def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaAttention):
|
||||||
|
module.__class__ = NPULlamaAttention
|
||||||
|
module.setup()
|
||||||
|
if isinstance(module, LlamaRMSNorm):
|
||||||
|
module.__class__ = NPURMSNorm
|
||||||
|
|
|
@ -17,7 +17,7 @@ import torch
|
||||||
|
|
||||||
def unwrap(model):
|
def unwrap(model):
|
||||||
if hasattr(model, "module"):
|
if hasattr(model, "module"):
|
||||||
return unwrap_model(model.module)
|
return model.unwrap()
|
||||||
else:
|
else:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
|
||||||
--warmup_steps 100 \
|
--warmup_steps 100 \
|
||||||
--use_grad_checkpoint \
|
--use_grad_checkpoint \
|
||||||
--use_flash_attn \
|
--use_flash_attn \
|
||||||
|
--pad_token "unk"
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -16,22 +16,24 @@ from colossal_llama2.dataset.loader import (
|
||||||
DataCollatorForSupervisedDataset,
|
DataCollatorForSupervisedDataset,
|
||||||
StatefulDistributedSampler,
|
StatefulDistributedSampler,
|
||||||
load_tokenized_dataset,
|
load_tokenized_dataset,
|
||||||
setup_distributed_dataloader,
|
|
||||||
)
|
)
|
||||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||||
|
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: torch.nn.Module) -> int:
|
def get_model_numel(model: torch.nn.Module) -> int:
|
||||||
|
@ -83,6 +85,7 @@ def main() -> None:
|
||||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
||||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
||||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||||
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
||||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
||||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
||||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
||||||
|
@ -108,6 +111,12 @@ def main() -> None:
|
||||||
default=False,
|
default=False,
|
||||||
help="Use flash-attention",
|
help="Use flash-attention",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_neft",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Use NEFTune",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--freeze_non_embeds_params",
|
"--freeze_non_embeds_params",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -116,6 +125,8 @@ def main() -> None:
|
||||||
)
|
)
|
||||||
parser.add_argument("--tp", type=int, default=1)
|
parser.add_argument("--tp", type=int, default=1)
|
||||||
parser.add_argument("--zero", type=int, default=1)
|
parser.add_argument("--zero", type=int, default=1)
|
||||||
|
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
||||||
|
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with open(args.config_file, "w") as f:
|
with open(args.config_file, "w") as f:
|
||||||
|
@ -125,6 +136,7 @@ def main() -> None:
|
||||||
# Initialize Distributed Training
|
# Initialize Distributed Training
|
||||||
# ==============================
|
# ==============================
|
||||||
colossalai.launch_from_torch({})
|
colossalai.launch_from_torch({})
|
||||||
|
accelerator = get_accelerator()
|
||||||
coordinator = DistCoordinator()
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
|
@ -182,7 +194,10 @@ def main() -> None:
|
||||||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
||||||
# ======================================================
|
# ======================================================
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
||||||
tokenizer.pad_token = tokenizer.unk_token
|
if args.pad_token == "eos":
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
elif args.pad_token == "unk":
|
||||||
|
tokenizer.pad_token = tokenizer.unk_token
|
||||||
tokenizer.add_bos_token = False
|
tokenizer.add_bos_token = False
|
||||||
tokenizer.add_eos_token = False
|
tokenizer.add_eos_token = False
|
||||||
|
|
||||||
|
@ -193,38 +208,36 @@ def main() -> None:
|
||||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
||||||
|
|
||||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
||||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
data_collator = DataCollatorForSupervisedDataset(
|
||||||
dataloader = setup_distributed_dataloader(
|
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
||||||
|
)
|
||||||
|
dataloader = plugin.prepare_dataloader(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
batch_size=args.micro_batch_size,
|
batch_size=args.micro_batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
collate_fn=data_collator,
|
collate_fn=data_collator,
|
||||||
|
distributed_sampler_cls=StatefulDistributedSampler,
|
||||||
)
|
)
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ======================================================
|
# ======================================================
|
||||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||||
# ======================================================
|
# ======================================================
|
||||||
|
init_ctx = (
|
||||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
LazyInitContext(default_device=get_current_device())
|
||||||
try:
|
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
||||||
from colossalai.accelerator import get_accelerator
|
else nullcontext()
|
||||||
|
)
|
||||||
current_device = get_accelerator().get_current_device()
|
|
||||||
except:
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
|
|
||||||
current_device = get_current_device()
|
|
||||||
|
|
||||||
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
||||||
# Freeze part of parameters.
|
# Freeze part of parameters.
|
||||||
if args.freeze_non_embeds_params:
|
if args.freeze_non_embeds_params:
|
||||||
freeze_non_embeds_parameters(model=model)
|
freeze_non_embeds_parameters(model=model)
|
||||||
|
# this is essential, otherwise the grad checkpoint will not work.
|
||||||
|
model.train()
|
||||||
|
|
||||||
if args.use_grad_checkpoint:
|
if args.use_grad_checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
|
@ -246,12 +259,14 @@ def main() -> None:
|
||||||
adamw_mode=True,
|
adamw_mode=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.warmup_steps is None:
|
||||||
|
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
||||||
|
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
||||||
|
|
||||||
lr_scheduler = CosineAnnealingWarmupLR(
|
lr_scheduler = CosineAnnealingWarmupLR(
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
total_steps=args.num_epochs * len(dataloader),
|
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
||||||
warmup_steps=args.warmup_steps
|
warmup_steps=args.warmup_steps,
|
||||||
if args.warmup_steps is not None
|
|
||||||
else int(args.num_epochs * len(dataloader) * 0.025),
|
|
||||||
eta_min=0.1 * args.lr,
|
eta_min=0.1 * args.lr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -267,11 +282,9 @@ def main() -> None:
|
||||||
|
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
|
|
||||||
if args.load_checkpoint is None:
|
coordinator.print_on_master(
|
||||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
booster.load_model(model, args.pretrained, strict=False)
|
)
|
||||||
|
|
||||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||||
)
|
)
|
||||||
|
@ -298,85 +311,109 @@ def main() -> None:
|
||||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
||||||
|
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
)
|
)
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
)
|
)
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
||||||
)
|
)
|
||||||
|
|
||||||
num_steps_per_epoch = len(dataloader)
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Activate NEFTune.")
|
||||||
|
model, handle = activate_neftune(model)
|
||||||
|
|
||||||
|
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
||||||
# If resume training, set the sampler start index to the correct value
|
# If resume training, set the sampler start index to the correct value
|
||||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
||||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.num_epochs):
|
for epoch in range(start_epoch, args.num_epochs):
|
||||||
dataloader.sampler.set_epoch(epoch=epoch)
|
dataloader.sampler.set_epoch(epoch=epoch)
|
||||||
with tqdm(
|
pbar = tqdm(
|
||||||
iterable=enumerate(dataloader, start=start_step),
|
|
||||||
desc=f"Epoch {epoch}",
|
desc=f"Epoch {epoch}",
|
||||||
disable=not coordinator.is_master(),
|
disable=not coordinator.is_master(),
|
||||||
total=num_steps_per_epoch,
|
total=num_steps_per_epoch,
|
||||||
initial=start_step,
|
initial=start_step // args.accumulation_steps,
|
||||||
) as pbar:
|
)
|
||||||
for step, batch in pbar:
|
total_loss = torch.tensor(0.0, device=get_current_device())
|
||||||
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
for step, batch in enumerate(dataloader, start=start_step):
|
||||||
|
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||||
|
|
||||||
batch_output = model(**batch)
|
batch_output = model(**batch)
|
||||||
|
|
||||||
loss = batch_output.loss
|
loss = batch_output.loss / args.accumulation_steps
|
||||||
|
total_loss.add_(loss.data)
|
||||||
|
|
||||||
booster.backward(loss=loss, optimizer=optimizer)
|
booster.backward(loss=loss, optimizer=optimizer)
|
||||||
|
|
||||||
|
if (step + 1) % args.accumulation_steps == 0:
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
all_reduce_mean(tensor=loss)
|
all_reduce_mean(tensor=total_loss)
|
||||||
pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
|
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
||||||
if coordinator.is_master():
|
if coordinator.is_master():
|
||||||
global_step = epoch * num_steps_per_epoch + step
|
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
||||||
writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
|
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
tag="Learning Rate",
|
tag="Learning Rate",
|
||||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
scalar_value=lr_scheduler.get_last_lr()[0],
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
)
|
)
|
||||||
# Save modeling.
|
total_loss.fill_(0.0)
|
||||||
|
pbar.update()
|
||||||
|
# Save modeling.
|
||||||
|
|
||||||
if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
|
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
||||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
step + 1
|
||||||
save_checkpoint(
|
) == len(dataloader):
|
||||||
save_dir=args.save_dir,
|
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
booster=booster,
|
|
||||||
model=model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
epoch=epoch,
|
|
||||||
step=step + 1,
|
|
||||||
batch_size=args.micro_batch_size,
|
|
||||||
coordinator=coordinator,
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delete CUDA cache.
|
if args.use_neft:
|
||||||
# del batch, batch_labels, batch_output, loss
|
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
||||||
torch.cuda.empty_cache()
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
|
accelerator.empty_cache()
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=args.save_dir,
|
||||||
|
booster=booster,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
epoch=epoch,
|
||||||
|
step=step + 1,
|
||||||
|
batch_size=args.micro_batch_size,
|
||||||
|
coordinator=coordinator,
|
||||||
|
)
|
||||||
|
coordinator.print_on_master(
|
||||||
|
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Activate NEFTune.")
|
||||||
|
model, handle = activate_neftune(model)
|
||||||
|
|
||||||
|
# Delete cache.
|
||||||
|
# del batch, batch_labels, batch_output, loss
|
||||||
|
accelerator.empty_cache()
|
||||||
|
|
||||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
||||||
dataloader.sampler.set_start_index(start_index=0)
|
dataloader.sampler.set_start_index(start_index=0)
|
||||||
start_step = 0
|
start_step = 0
|
||||||
|
|
||||||
|
if args.use_neft:
|
||||||
|
coordinator.print_on_master("Deactivate NEFTune.")
|
||||||
|
deactivate_neftune(model, handle)
|
||||||
|
|
||||||
# Final save.
|
# Final save.
|
||||||
coordinator.print_on_master("Start saving final model checkpoint")
|
coordinator.print_on_master("Start saving final model checkpoint")
|
||||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||||
|
|
||||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
|
||||||
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
|
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
|
||||||
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
|
||||||
|
|
||||||
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
|
colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
|
||||||
--pretrained $PRETRAINED_MODEL_PATH \
|
--pretrained $PRETRAINED_MODEL_PATH \
|
||||||
--dataset ${dataset[@]} \
|
--dataset ${dataset[@]} \
|
||||||
--plugin "zero2" \
|
--plugin "zero2" \
|
||||||
|
@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
|
||||||
--use_grad_checkpoint \
|
--use_grad_checkpoint \
|
||||||
--use_flash_attn \
|
--use_flash_attn \
|
||||||
--use_neft \
|
--use_neft \
|
||||||
|
--pad_token "eos"
|
||||||
|
|
|
@ -1,403 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import resource
|
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from colossal_llama2.dataset.loader import (
|
|
||||||
DataCollatorForSupervisedDataset,
|
|
||||||
StatefulDistributedSampler,
|
|
||||||
load_tokenized_dataset,
|
|
||||||
setup_distributed_dataloader,
|
|
||||||
)
|
|
||||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
|
||||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
|
||||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
|
||||||
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
from tqdm import tqdm
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.booster import Booster
|
|
||||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
|
||||||
from colossalai.cluster import DistCoordinator
|
|
||||||
from colossalai.lazy import LazyInitContext
|
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: torch.nn.Module) -> int:
|
|
||||||
return sum(p.numel() for p in model.parameters())
|
|
||||||
|
|
||||||
|
|
||||||
def format_numel_str(numel: int) -> str:
|
|
||||||
B = 1024**3
|
|
||||||
M = 1024**2
|
|
||||||
K = 1024
|
|
||||||
if numel >= B:
|
|
||||||
return f"{numel / B:.2f} B"
|
|
||||||
elif numel >= M:
|
|
||||||
return f"{numel / M:.2f} M"
|
|
||||||
elif numel >= K:
|
|
||||||
return f"{numel / K:.2f} K"
|
|
||||||
else:
|
|
||||||
return f"{numel}"
|
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
|
||||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
|
||||||
tensor.div_(dist.get_world_size())
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
# ==============================
|
|
||||||
# Parse Arguments
|
|
||||||
# ==============================
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Address of the pre-trained modeling",
|
|
||||||
)
|
|
||||||
parser.add_argument("--dataset", nargs="+", default=[])
|
|
||||||
parser.add_argument(
|
|
||||||
"--plugin",
|
|
||||||
type=str,
|
|
||||||
default="gemini",
|
|
||||||
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
|
||||||
help="Choose which plugin to use",
|
|
||||||
)
|
|
||||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
|
||||||
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
|
||||||
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
|
||||||
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
|
||||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
|
|
||||||
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
|
||||||
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
|
||||||
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
|
|
||||||
parser.add_argument(
|
|
||||||
"--mixed_precision",
|
|
||||||
type=str,
|
|
||||||
default="fp16",
|
|
||||||
choices=["fp16", "bf16"],
|
|
||||||
help="Mixed precision",
|
|
||||||
)
|
|
||||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
|
||||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
|
||||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_grad_checkpoint",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use gradient checkpointing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attn",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use flash-attention",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_neft",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use NEFTune",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--freeze_non_embeds_params",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Freeze non embeddings parameters",
|
|
||||||
)
|
|
||||||
parser.add_argument("--tp", type=int, default=1)
|
|
||||||
parser.add_argument("--zero", type=int, default=1)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with open(args.config_file, "w") as f:
|
|
||||||
json.dump(args.__dict__, f, indent=4)
|
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Initialize Distributed Training
|
|
||||||
# ==============================
|
|
||||||
colossalai.launch_from_torch({})
|
|
||||||
coordinator = DistCoordinator()
|
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Initialize Tensorboard
|
|
||||||
# ==============================
|
|
||||||
if coordinator.is_master():
|
|
||||||
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
|
||||||
writer = SummaryWriter(args.tensorboard_dir)
|
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Initialize Booster
|
|
||||||
# ==============================
|
|
||||||
if args.plugin == "gemini":
|
|
||||||
plugin = GeminiPlugin(
|
|
||||||
precision=args.mixed_precision,
|
|
||||||
initial_scale=2**16,
|
|
||||||
max_norm=args.grad_clip,
|
|
||||||
)
|
|
||||||
elif args.plugin == "gemini_auto":
|
|
||||||
plugin = GeminiPlugin(
|
|
||||||
precision=args.mixed_precision,
|
|
||||||
placement_policy="auto",
|
|
||||||
initial_scale=2**16,
|
|
||||||
max_norm=args.grad_clip,
|
|
||||||
)
|
|
||||||
elif args.plugin == "zero2":
|
|
||||||
plugin = LowLevelZeroPlugin(
|
|
||||||
stage=2,
|
|
||||||
precision=args.mixed_precision,
|
|
||||||
initial_scale=2**16,
|
|
||||||
max_norm=args.grad_clip,
|
|
||||||
)
|
|
||||||
elif args.plugin == "zero2_cpu":
|
|
||||||
plugin = LowLevelZeroPlugin(
|
|
||||||
stage=2,
|
|
||||||
precision=args.mixed_precision,
|
|
||||||
initial_scale=2**16,
|
|
||||||
cpu_offload=True,
|
|
||||||
max_norm=args.grad_clip,
|
|
||||||
)
|
|
||||||
elif args.plugin == "3d":
|
|
||||||
plugin = HybridParallelPlugin(
|
|
||||||
tp_size=args.tp,
|
|
||||||
pp_size=1,
|
|
||||||
zero_stage=args.zero,
|
|
||||||
max_norm=args.grad_clip,
|
|
||||||
precision=args.mixed_precision,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
|
||||||
|
|
||||||
booster = Booster(plugin=plugin)
|
|
||||||
|
|
||||||
# ======================================================
|
|
||||||
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
|
||||||
# ======================================================
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
tokenizer.add_bos_token = False
|
|
||||||
tokenizer.add_eos_token = False
|
|
||||||
|
|
||||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
|
||||||
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
|
||||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
|
||||||
|
|
||||||
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
|
||||||
|
|
||||||
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
|
||||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
|
|
||||||
dataloader = setup_distributed_dataloader(
|
|
||||||
dataset=dataset,
|
|
||||||
batch_size=args.micro_batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
collate_fn=data_collator,
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ======================================================
|
|
||||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
|
||||||
# ======================================================
|
|
||||||
init_ctx = (
|
|
||||||
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
|
||||||
)
|
|
||||||
with init_ctx:
|
|
||||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
|
||||||
# Freeze part of parameters.
|
|
||||||
if args.freeze_non_embeds_params:
|
|
||||||
freeze_non_embeds_parameters(model=model)
|
|
||||||
|
|
||||||
if args.use_grad_checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
|
||||||
if args.use_flash_attn:
|
|
||||||
replace_with_flash_attention(model=model)
|
|
||||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
|
||||||
|
|
||||||
model_numel = get_model_numel(model)
|
|
||||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
|
||||||
|
|
||||||
optimizer = HybridAdam(
|
|
||||||
model_params=filter(lambda p: p.requires_grad, model.parameters())
|
|
||||||
if args.freeze_non_embeds_params
|
|
||||||
else model.parameters(),
|
|
||||||
lr=args.lr,
|
|
||||||
betas=(0.9, 0.95),
|
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
adamw_mode=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.warmup_steps is None:
|
|
||||||
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
|
||||||
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
|
||||||
|
|
||||||
lr_scheduler = CosineAnnealingWarmupLR(
|
|
||||||
optimizer=optimizer,
|
|
||||||
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
|
||||||
warmup_steps=args.warmup_steps,
|
|
||||||
eta_min=0.1 * args.lr,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flash attention will be disabled because it does NOT support fp32.
|
|
||||||
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
|
||||||
torch.set_default_dtype(default_dtype)
|
|
||||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
|
||||||
model=model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
dataloader=dataloader,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.set_default_dtype(torch.float)
|
|
||||||
|
|
||||||
if args.load_checkpoint is None:
|
|
||||||
coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
|
|
||||||
booster.load_model(model, args.pretrained, strict=False)
|
|
||||||
|
|
||||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
|
||||||
)
|
|
||||||
|
|
||||||
start_epoch = 0
|
|
||||||
start_step = 0
|
|
||||||
sampler_start_idx = 0
|
|
||||||
if args.load_checkpoint is not None:
|
|
||||||
if "modeling" in args.load_checkpoint:
|
|
||||||
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
|
|
||||||
booster.load_model(model, args.load_checkpoint)
|
|
||||||
else:
|
|
||||||
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
|
|
||||||
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
|
||||||
load_dir=args.load_checkpoint,
|
|
||||||
booster=booster,
|
|
||||||
model=model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
|
||||||
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.use_neft:
|
|
||||||
coordinator.print_on_master("Activate NEFTune.")
|
|
||||||
model, handle = activate_neftune(model)
|
|
||||||
|
|
||||||
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
|
||||||
# If resume training, set the sampler start index to the correct value
|
|
||||||
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
|
||||||
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.num_epochs):
|
|
||||||
dataloader.sampler.set_epoch(epoch=epoch)
|
|
||||||
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
|
|
||||||
total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
|
|
||||||
for step, batch in enumerate(dataloader):
|
|
||||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
|
||||||
|
|
||||||
batch_output = model(**batch)
|
|
||||||
|
|
||||||
loss = batch_output.loss / args.accumulation_steps
|
|
||||||
total_loss += loss.item()
|
|
||||||
|
|
||||||
booster.backward(loss=loss, optimizer=optimizer)
|
|
||||||
|
|
||||||
if (step + 1) % args.accumulation_steps == 0:
|
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
all_reduce_mean(tensor=total_loss)
|
|
||||||
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
|
||||||
if coordinator.is_master():
|
|
||||||
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
|
||||||
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
|
||||||
writer.add_scalar(
|
|
||||||
tag="Learning Rate",
|
|
||||||
scalar_value=lr_scheduler.get_last_lr()[0],
|
|
||||||
global_step=global_step,
|
|
||||||
)
|
|
||||||
total_loss.fill_(0.0)
|
|
||||||
pbar.update()
|
|
||||||
# Save modeling.
|
|
||||||
|
|
||||||
if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
|
|
||||||
step + 1
|
|
||||||
) == len(dataloader):
|
|
||||||
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
|
||||||
|
|
||||||
if args.use_neft:
|
|
||||||
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
|
||||||
deactivate_neftune(model, handle)
|
|
||||||
|
|
||||||
save_checkpoint(
|
|
||||||
save_dir=args.save_dir,
|
|
||||||
booster=booster,
|
|
||||||
model=model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
epoch=epoch,
|
|
||||||
step=step + 1,
|
|
||||||
batch_size=args.micro_batch_size,
|
|
||||||
coordinator=coordinator,
|
|
||||||
)
|
|
||||||
coordinator.print_on_master(
|
|
||||||
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.use_neft:
|
|
||||||
coordinator.print_on_master("Activate NEFTune.")
|
|
||||||
model, handle = activate_neftune(model)
|
|
||||||
|
|
||||||
# Delete CUDA cache.
|
|
||||||
# del batch, batch_labels, batch_output, loss
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
|
||||||
dataloader.sampler.set_start_index(start_index=0)
|
|
||||||
start_step = 0
|
|
||||||
|
|
||||||
if args.use_neft:
|
|
||||||
coordinator.print_on_master("Deactivate NEFTune.")
|
|
||||||
deactivate_neftune(model, handle)
|
|
||||||
|
|
||||||
# Final save.
|
|
||||||
coordinator.print_on_master("Start saving final model checkpoint")
|
|
||||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
|
||||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
|
||||||
|
|
||||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -3,6 +3,8 @@ from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from .huggingface import HuggingFaceModel
|
from .huggingface import HuggingFaceModel
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
|
||||||
"""
|
"""
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||||
).to(torch.cuda.current_device())
|
).to(get_current_device())
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||||
torch.cuda.current_device()
|
get_current_device()
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.model(input_ids)[0]
|
outputs = self.model(input_ids)[0]
|
||||||
|
@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
max_length=self.model_max_length - max_new_tokens,
|
max_length=self.model_max_length - max_new_tokens,
|
||||||
).to(torch.cuda.current_device())
|
).to(get_current_device())
|
||||||
|
|
||||||
# Set output_scores=True to get prediction scores.
|
# Set output_scores=True to get prediction scores.
|
||||||
outputs = self.model.generate(
|
outputs = self.model.generate(
|
||||||
|
|
|
@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
|
||||||
|
|
||||||
from colossalai.logging import DistributedLogger
|
from colossalai.logging import DistributedLogger
|
||||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from .base import BaseModel
|
from .base import BaseModel
|
||||||
|
|
||||||
|
@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
|
||||||
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
self.model = AutoModel.from_pretrained(path, **model_kwargs)
|
||||||
shard_former = ShardFormer(shard_config)
|
shard_former = ShardFormer(shard_config)
|
||||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||||
self.model.to(torch.cuda.current_device())
|
self.model.to(get_current_device())
|
||||||
|
|
||||||
if peft_path is not None:
|
if peft_path is not None:
|
||||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||||
else:
|
else:
|
||||||
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||||
if peft_path is not None:
|
if peft_path is not None:
|
||||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
|
||||||
"""
|
"""
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
||||||
).to(torch.cuda.current_device())
|
).to(get_current_device())
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
|
||||||
torch.cuda.current_device()
|
get_current_device()
|
||||||
)
|
)
|
||||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
|
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
|
||||||
|
|
||||||
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
|
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
|
||||||
|
|
||||||
|
@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_token_type_ids=False,
|
return_token_type_ids=False,
|
||||||
max_length=self.model_max_length - max_new_tokens,
|
max_length=self.model_max_length - max_new_tokens,
|
||||||
).to(torch.cuda.current_device())
|
).to(get_current_device())
|
||||||
|
|
||||||
# Set output_scores=True to get prediction scores.
|
# Set output_scores=True to get prediction scores.
|
||||||
outputs = self.model.generate(
|
outputs = self.model.generate(
|
||||||
|
@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
|
||||||
shard_former = ShardFormer(shard_config)
|
shard_former = ShardFormer(shard_config)
|
||||||
self.model, sharded_parameters = shard_former.optimize(self.model)
|
self.model, sharded_parameters = shard_former.optimize(self.model)
|
||||||
self.model.to(torch.cuda.current_device())
|
self.model.to(get_current_device())
|
||||||
|
|
||||||
if peft_path is not None:
|
if peft_path is not None:
|
||||||
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
|
||||||
else:
|
else:
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
|
||||||
if peft_path is not None:
|
if peft_path is not None:
|
||||||
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import torch.distributed as dist
|
||||||
from colossal_eval import dataset, models, utils
|
from colossal_eval import dataset, models, utils
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.shardformer import ShardConfig
|
from colossalai.shardformer import ShardConfig
|
||||||
|
@ -82,6 +83,7 @@ def rm_and_merge(
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
colossalai.launch_from_torch(config={}, seed=42)
|
colossalai.launch_from_torch(config={}, seed=42)
|
||||||
|
accelerator = get_accelerator()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
@ -235,10 +237,10 @@ def main(args):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
|
logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
|
||||||
|
|
||||||
del model_
|
del model_
|
||||||
torch.cuda.empty_cache()
|
accelerator.empty_cache()
|
||||||
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
|
|
@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
|
||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
|
|
||||||
def prepare_dataloader(
|
def prepare_dataloader(
|
||||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
seed=1024,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=False,
|
||||||
|
num_workers=0,
|
||||||
|
distributed_sampler_cls=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||||
|
@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
|
||||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||||
"""
|
"""
|
||||||
_kwargs = kwargs.copy()
|
_kwargs = kwargs.copy()
|
||||||
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||||
|
sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
||||||
|
|
||||||
# Deterministic dataloader
|
# Deterministic dataloader
|
||||||
def seed_worker(worker_id):
|
def seed_worker(worker_id):
|
||||||
|
|
|
@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
|
||||||
return ["cuda", "npu"]
|
return ["cuda", "npu"]
|
||||||
|
|
||||||
def prepare_dataloader(
|
def prepare_dataloader(
|
||||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
seed=1024,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=False,
|
||||||
|
num_workers=0,
|
||||||
|
distributed_sampler_cls=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||||
|
@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
|
||||||
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
|
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
|
||||||
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
|
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
|
||||||
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
|
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
|
||||||
sampler = DistributedSampler(
|
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||||
|
sampler = distributed_sampler_cls(
|
||||||
dataset,
|
dataset,
|
||||||
num_replicas=zero_world_size * extra_dp_world_size,
|
num_replicas=zero_world_size * extra_dp_world_size,
|
||||||
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
|
rank=zero_rank * extra_dp_world_size + extra_dp_rank,
|
||||||
|
|
|
@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def prepare_dataloader(
|
def prepare_dataloader(
|
||||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
seed=1024,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=False,
|
||||||
|
num_workers=0,
|
||||||
|
distributed_sampler_cls=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||||
|
@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||||
"""
|
"""
|
||||||
_kwargs = kwargs.copy()
|
_kwargs = kwargs.copy()
|
||||||
sampler = DistributedSampler(
|
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
|
||||||
|
sampler = distributed_sampler_cls(
|
||||||
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
|
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from .general_checkpoint_io import GeneralCheckpointIO
|
from .general_checkpoint_io import GeneralCheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
|
@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||||
tp_group=self.tp_group,
|
tp_group=self.tp_group,
|
||||||
use_zero=self.use_zero,
|
use_zero=self.use_zero,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
device=torch.device("cuda"),
|
device=get_current_device(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pp_size == 1:
|
if self.pp_size == 1:
|
||||||
|
@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||||
if isinstance(v, torch.Tensor) and k != "step":
|
if isinstance(v, torch.Tensor) and k != "step":
|
||||||
# First gather Zero shards.
|
# First gather Zero shards.
|
||||||
if use_zero:
|
if use_zero:
|
||||||
v = v.cuda()
|
v = v.to(get_current_device())
|
||||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||||
|
|
Loading…
Reference in New Issue