pull/5190/head
Xuanlei Zhao 2023-12-25 16:29:47 +08:00
parent aa2e091dc6
commit 23341687ed
5 changed files with 40 additions and 113 deletions

View File

@@ -2,52 +2,14 @@ import logging
 import os
 from pathlib import Path
-import torch.distributed as dist
-import torch.nn as nn
-
-import copy
-import logging
-import os
-from pathlib import Path
-from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ProcessGroup
-
-from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
-from colossalai.checkpoint_io.utils import (
-    StateDictSharder,
-    gather_distributed_param,
-    get_model_base_filenames,
-    get_optimizer_base_filenames,
-    is_safetensors_available,
-    load_shard_state_dict,
-    load_state_dict,
-    load_state_dict_into_model,
-    load_states_into_optimizer,
-    save_config_file,
-    save_param_groups,
-    save_state_dict,
-    save_state_dict_shards,
-    sharded_optimizer_loading_epilogue,
-)
-from colossalai.interface import OptimizerWrapper
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.tensor.moe_tensor.api import (
-    get_dp_group,
-    get_dp_rank,
-    get_dp_size,
-    get_ep_group,
-    get_ep_rank,
-    get_ep_size,
-    is_moe_tensor,
-)
 from colossalai.checkpoint_io import CheckpointIndexFile
 from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
 from colossalai.moe import MoECheckpintIO
-from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, is_moe_tensor
+from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor

 class MixtralMoECheckpointIO(MoECheckpintIO):
@@ -62,8 +24,8 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         model_param_dict = dict(model.named_parameters())
         for name, param in list(state_dict.items()):
             if ".gate.weight" in name:
                 new_name = "module." + name.replace(".gate.weight", ".gate_weight")
                 state_dict[new_name] = state_dict.pop(name)
             elif ".experts." in name:
                 # if is moe tensor
                 # in our moe module, expert is cat as one tensor
@@ -94,7 +56,7 @@
                 state_dict[model_param_name] = new_param
                 state_dict.pop(name)
             else:
                 new_name = "module." + name
                 state_dict[new_name] = state_dict.pop(name)
         for name, param in list(state_dict.items()):
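Note on the hunk above: the kept loop translates HuggingFace Mixtral checkpoint keys into the names used by the wrapped module, so ".gate.weight" becomes ".gate_weight" under a "module." prefix, and every other key simply gains the "module." prefix (the ".experts." branch, which concatenates per-expert weights into one tensor, is elided in the view). A minimal standalone sketch of that renaming follows; remap_mixtral_keys is an illustrative name, not part of the repo.

from collections import OrderedDict

import torch


def remap_mixtral_keys(state_dict):
    # Illustrative helper mirroring the renaming pattern shown in the diff;
    # the expert-merging branch is intentionally omitted.
    remapped = OrderedDict()
    for name, param in state_dict.items():
        if ".gate.weight" in name:
            # Router weight: ".gate.weight" -> ".gate_weight", plus a "module." prefix.
            remapped["module." + name.replace(".gate.weight", ".gate_weight")] = param
        else:
            # All remaining keys only gain the "module." prefix.
            remapped["module." + name] = param
    return remapped


sd = OrderedDict(
    {
        "model.layers.0.block_sparse_moe.gate.weight": torch.zeros(8, 16),
        "model.embed_tokens.weight": torch.zeros(32, 16),
    }
)
print(list(remap_mixtral_keys(sd)))
# ['module.model.layers.0.block_sparse_moe.gate_weight', 'module.model.embed_tokens.weight']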

View File

@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock

 from colossalai.lazy import LazyInitContext
 from colossalai.moe import SparseMLP
-from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor

 class MixtralSparseMLP:
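This file only reorders imports and drops an unused one, but for context: a replace_moe_layer-style helper typically walks the module tree and swaps each MixtralSparseMoeBlock for a ColossalAI SparseMLP. The sketch below shows that generic recursive swap with toy stand-in classes so it runs without transformers or colossalai installed; none of these names come from the repo.

import torch.nn as nn


def swap_modules(root: nn.Module, old_cls: type, build_new) -> None:
    # Recursively replace every child of type `old_cls` with `build_new(child)`.
    for name, child in root.named_children():
        if isinstance(child, old_cls):
            setattr(root, name, build_new(child))
        else:
            swap_modules(child, old_cls, build_new)


class ToyMoeBlock(nn.Module):
    # Stand-in for a sparse MoE block.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


model = nn.Sequential(nn.Linear(4, 4), ToyMoeBlock())
swap_modules(model, ToyMoeBlock, lambda old: nn.Linear(4, 4))
print(model)  # the ToyMoeBlock entry is now an nn.Linear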

View File

@@ -1,12 +1,12 @@
 import argparse
-import os

 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.models.mixtral_layer import replace_moe_layer
-from torch.utils.data import Dataset
-from tqdm import tqdm
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -14,29 +14,7 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.moe import MOE_MANAGER, apply_load_balance
-from colossalai.utils import get_current_device
-
-import argparse
-import os
-from functools import partial
-from typing import Dict
-
-import torch
-import torch.distributed as dist
-from datasets import load_dataset
-from huggingface_hub import snapshot_download
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from transformers import T5Tokenizer
-from transformers.models.llama import LlamaConfig
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
@@ -88,9 +66,7 @@ def parse_args():
         choices=["fp32", "bf16", "fp16"],
         help="The mixed precision training.",
     )
-    parser.add_argument(
-        "--seed", type=int, default=42, help="A seed for reproducible training."
-    )
+    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

     # kernel
     parser.add_argument(
@@ -147,11 +123,7 @@
     config.num_local_experts = 1  # dont change this. it will not affect model
     with skip_init():
         model = MixtralForCausalLM(config)
-    model = (
-        model.to(torch.bfloat16)
-        if args.precision == "bf16"
-        else model.to(torch.float16)
-    )
+    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
     model = model.to(get_current_device())
     coordinator.print_on_master(f"Finish init model with config:\n{config}")

View File

@@ -1,4 +1,3 @@
-import importlib
 import os
 import shutil
 import sys
@@ -6,7 +5,9 @@ import sys
 import pytest
 import torch
 import torch.distributed as dist
-from transformers.models.llama import LlamaConfig
+from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

 import colossalai
 from colossalai.booster import Booster
@@ -14,9 +15,6 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy

 sys.path.append(
     os.path.join(
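The test above imports check_state_dict_equal from colossalai.testing to compare model state before and after a checkpoint round trip. As a rough illustration of what such a comparison boils down to (a generic sketch, not the ColossalAI implementation):

import torch


def state_dicts_equal(sd_a: dict, sd_b: dict) -> bool:
    # Same keys, and every corresponding tensor matches exactly.
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[k].cpu(), sd_b[k].cpu()) for k in sd_a)


print(state_dicts_equal({"w": torch.ones(2, 2)}, {"w": torch.ones(2, 2)}))   # True
print(state_dicts_equal({"w": torch.ones(2, 2)}, {"w": torch.zeros(2, 2)}))  # False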

View File

@@ -1,10 +1,15 @@
 import argparse
-import os

 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from huggingface_hub import snapshot_download
+
+# from colossalai.nn.optimizer import HybridAdam
+from torch.optim import Adam as HybridAdam
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -14,37 +19,16 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
 from colossalai.moe import MOE_MANAGER, apply_load_balance
-# from colossalai.nn.optimizer import HybridAdam
-from torch.optim import Adam as HybridAdam
-from colossalai.utils import get_current_device
-
-import argparse
-import os
-from functools import partial
-from typing import Dict
-
-import torch
-import torch.distributed as dist
-from datasets import load_dataset
-from huggingface_hub import snapshot_download
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from transformers import T5Tokenizer
-from transformers.models.llama import LlamaConfig
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device


 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}


 def load_ckpt(repo_name: str, model, booster: Booster):
     ckpt_path = snapshot_download(repo_name)
     # single ckpt
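The load_ckpt helper kept above downloads a checkpoint with huggingface_hub.snapshot_download and then, per the "# single ckpt" comment, hands it to the booster. A hedged sketch of how that hand-off could look, assuming Booster.load_model accepts a local checkpoint path; the single-versus-sharded branching below is an assumption for illustration, not the repo's code.

import os

from huggingface_hub import snapshot_download


def load_pretrained(repo_name: str, model, booster):
    # Download (or reuse the cached copy of) the checkpoint repository.
    ckpt_path = snapshot_download(repo_name)
    single_file = os.path.join(ckpt_path, "pytorch_model.bin")
    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    if os.path.exists(single_file):
        # Single-file checkpoint.
        booster.load_model(model, single_file)
    elif os.path.exists(index_file):
        # Sharded checkpoint: pass the index file so shards can be resolved.
        booster.load_model(model, index_file)
    else:
        raise FileNotFoundError(f"No recognised checkpoint found in {ckpt_path}")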
@@ -284,7 +268,9 @@ def main():
     model = model.to(get_current_device())
     replace_moe_layer(model)
     # torch.set_default_tensor_type(torch.float32)
-    print(f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
     coordinator.print_on_master(f"Finish init model with config:\n{config}")

     # Enable gradient checkpointing
@@ -298,7 +284,9 @@ def main():
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
     )
     torch.cuda.synchronize()
-    print(f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )

     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@@ -307,10 +295,14 @@ def main():
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     torch.cuda.synchronize()
-    print(f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
     load_ckpt("mistralai/Mixtral-8x7B-v0.1", model, booster)
     torch.cuda.synchronize()
-    print(f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )

     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
@@ -348,11 +340,15 @@ def main():
             data = move_to_cuda(data, torch.cuda.current_device())
             outputs = model(**data)
             loss = outputs["loss"]
-            print(f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+            print(
+                f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+            )
             # Backward
             booster.backward(loss, optimizer)
-            print(f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+            print(
+                f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+            )
             pbar.set_postfix({"loss": loss.item()})
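About the recurring diagnostic prints above: dividing by 1000.0 ** 3 and labelling both values "GB" mixes units, since the first figure is a parameter count in billions rather than gigabytes, and CUDA memory is usually reported in GiB (1024 ** 3 bytes). A small helper that keeps the same information with explicit units (the function name is illustrative, not from the script):

import torch


def log_model_stats(tag: str, model: torch.nn.Module) -> None:
    # Parameter count in billions (1e9 parameters), not gigabytes.
    params_b = sum(p.numel() for p in model.parameters()) / 1e9
    # Allocated CUDA memory in GiB; 0 if no GPU is available.
    mem_gib = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0
    print(f"{tag}: params={params_b:.2f}B, cuda_mem={mem_gib:.2f}GiB")


log_model_stats("after init", torch.nn.Linear(1024, 1024))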