mirror of https://github.com/hpcaitech/ColossalAI

update
parent: aa2e091dc6
commit: 23341687ed
@@ -2,52 +2,14 @@ import logging
 import os
 from pathlib import Path
 
-import torch.distributed as dist
-import torch.nn as nn
-import copy
-import logging
-import os
-from pathlib import Path
-from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ProcessGroup
 
-from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
-from colossalai.checkpoint_io.utils import (
-    StateDictSharder,
-    gather_distributed_param,
-    get_model_base_filenames,
-    get_optimizer_base_filenames,
-    is_safetensors_available,
-    load_shard_state_dict,
-    load_state_dict,
-    load_state_dict_into_model,
-    load_states_into_optimizer,
-    save_config_file,
-    save_param_groups,
-    save_state_dict,
-    save_state_dict_shards,
-    sharded_optimizer_loading_epilogue,
-)
-from colossalai.interface import OptimizerWrapper
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.tensor.moe_tensor.api import (
-    get_dp_group,
-    get_dp_rank,
-    get_dp_size,
-    get_ep_group,
-    get_ep_rank,
-    get_ep_size,
-    is_moe_tensor,
-)
 from colossalai.checkpoint_io import CheckpointIndexFile
 from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
 from colossalai.moe import MoECheckpintIO
-from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, is_moe_tensor
+from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
 
 
 class MixtralMoECheckpointIO(MoECheckpintIO):
@@ -62,8 +24,8 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         model_param_dict = dict(model.named_parameters())
         for name, param in list(state_dict.items()):
             if ".gate.weight" in name:
                 new_name = "module." + name.replace(".gate.weight", ".gate_weight")
                 state_dict[new_name] = state_dict.pop(name)
             elif ".experts." in name:
                 # if is moe tensor
                 # in our moe module, expert is cat as one tensor
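The hunk above is from the checkpoint-loading path: Hugging Face Mixtral checkpoints name the router weight `*.gate.weight` and store every expert as a separate `*.experts.<i>.*` tensor, while the ColossalAI MoE module expects a `gate_weight` parameter and one concatenated tensor per expert weight (the "expert is cat as one tensor" comment). A minimal, self-contained sketch of that key remapping; the fused key layout here is an assumption for illustration, not the repository's exact naming:

```python
import torch


def remap_mixtral_keys(state_dict: dict, num_experts: int = 8) -> dict:
    """Illustrative only: rename the router weight and stack per-expert tensors."""
    out = {}
    expert_shards = {}  # hypothetical fused key -> one slot per expert
    for name, tensor in state_dict.items():
        if ".gate.weight" in name:
            out["module." + name.replace(".gate.weight", ".gate_weight")] = tensor
        elif ".experts." in name:
            # e.g. "...experts.3.w1.weight" -> fused key "...experts.w1.weight", slot 3
            prefix, _, rest = name.partition(".experts.")
            idx, _, suffix = rest.partition(".")
            fused = f"module.{prefix}.experts.{suffix}"
            expert_shards.setdefault(fused, [None] * num_experts)[int(idx)] = tensor
        else:
            out["module." + name] = tensor
    for fused, shards in expert_shards.items():
        out[fused] = torch.stack(shards, dim=0)  # all experts in one tensor
    return out
```

In the repository this is also where expert parallelism comes in (note the `get_ep_rank`/`get_ep_size` imports earlier in the file); the sketch above leaves the per-rank slicing out.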
@@ -94,7 +56,7 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
                 state_dict[model_param_name] = new_param
                 state_dict.pop(name)
             else:
                 new_name = "module." + name
                 state_dict[new_name] = state_dict.pop(name)
 
         for name, param in list(state_dict.items()):
@@ -1,10 +1,9 @@
 import torch
 import torch.nn as nn
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
 
 from colossalai.lazy import LazyInitContext
 from colossalai.moe import SparseMLP
-from colossalai.tensor.moe_tensor.api import get_ep_rank, is_moe_tensor
 
 
 class MixtralSparseMLP:
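The imports above pair `MixtralSparseMoeBlock` and `MixtralDecoderLayer` from transformers with ColossalAI's `SparseMLP`, i.e. the sparse MoE block inside each decoder layer gets swapped for the ColossalAI implementation, which is what the `replace_moe_layer` helper imported by the scripts below appears to drive. A generic sketch of that swap pattern, assuming a caller-supplied factory instead of the repository's actual conversion logic:

```python
import torch.nn as nn


def swap_moe_blocks(model: nn.Module, build_sparse_mlp) -> None:
    """Illustrative: replace every MixtralSparseMoeBlock child in-place.

    `build_sparse_mlp` is a hypothetical factory mapping the original block to its
    replacement, e.g. a ColossalAI SparseMLP built with the same hidden size,
    intermediate size and number of experts.
    """
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if child.__class__.__name__ == "MixtralSparseMoeBlock":
                setattr(parent, child_name, build_sparse_mlp(child))
```

Matching on the class name keeps the sketch independent of the transformers version; the real module can compare against `MixtralSparseMoeBlock` directly since it imports it.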
@@ -1,12 +1,12 @@
 import argparse
+import os
 
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.models.mixtral_layer import replace_moe_layer
-from torch.utils.data import Dataset
-from tqdm import tqdm
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
@@ -14,29 +14,7 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.moe import MOE_MANAGER, apply_load_balance
-from colossalai.utils import get_current_device
-import argparse
-import os
-from functools import partial
-from typing import Dict
-
-import torch
-import torch.distributed as dist
-from datasets import load_dataset
-from huggingface_hub import snapshot_download
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from transformers import T5Tokenizer
-from transformers.models.llama import LlamaConfig
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
 
@@ -88,9 +66,7 @@ def parse_args():
         choices=["fp32", "bf16", "fp16"],
         help="The mixed precision training.",
     )
-    parser.add_argument(
-        "--seed", type=int, default=42, help="A seed for reproducible training."
-    )
+    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
 
     # kernel
     parser.add_argument(
@@ -147,11 +123,7 @@ def main():
     config.num_local_experts = 1 # dont change this. it will not affect model
     with skip_init():
         model = MixtralForCausalLM(config)
-    model = (
-        model.to(torch.bfloat16)
-        if args.precision == "bf16"
-        else model.to(torch.float16)
-    )
+    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
     model = model.to(get_current_device())
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
 
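The rewrite above folds the bf16/fp16 choice into one conditional cast before the model is moved to the current device. The same mapping written as a small helper (illustrative; `torch_dtype_of` is an invented name, and `fp32` is assumed to mean "leave the weights as they are"):

```python
from typing import Optional

import torch


def torch_dtype_of(precision: str) -> Optional[torch.dtype]:
    """Hypothetical helper: map the --precision flag to a torch dtype."""
    return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": None}[precision]


# usage sketch:
# dtype = torch_dtype_of(args.precision)
# if dtype is not None:
#     model = model.to(dtype)
```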
@@ -1,4 +1,3 @@
-import importlib
 import os
 import shutil
 import sys
@@ -6,7 +5,9 @@ import sys
 import pytest
 import torch
 import torch.distributed as dist
-from transformers.models.llama import LlamaConfig
+from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
 import colossalai
 from colossalai.booster import Booster
@@ -14,9 +15,6 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 
 sys.path.append(
     os.path.join(
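The test imports above (`spawn`, `rerun_if_address_is_in_use`, `DummyDataloader`, `check_state_dict_equal`) are ColossalAI's usual distributed-test scaffolding: spawn one process per rank, retry if the rendezvous port is taken, and compare state dicts after a save/load round trip. A bare-bones sketch of that pattern; the worker body is left as comments because the actual test setup is not shown in this diff:

```python
import pytest

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank: int, world_size: int, port: int):
    # every spawned worker joins the same process group
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port)
    # ... build the model and booster, save a checkpoint, reload it, then
    # assert equality with check_state_dict_equal(saved_state_dict, reloaded_state_dict)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mixtral_checkpoint_roundtrip():
    spawn(run_dist, 4)
```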
@@ -1,10 +1,15 @@
 import argparse
+import os
 
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from huggingface_hub import snapshot_download
+
+# from colossalai.nn.optimizer import HybridAdam
+from torch.optim import Adam as HybridAdam
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
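Note the aliased import in the hunk above: `colossalai.nn.optimizer.HybridAdam` is commented out and `torch.optim.Adam` is bound to the same `HybridAdam` name, so the rest of the script keeps constructing `HybridAdam(...)` unchanged. The same idea written as an explicit switch (the flag is invented for illustration and is not in the script):

```python
# Illustrative only: choose the optimizer implementation behind a single name,
# so call sites such as HybridAdam(model.parameters(), lr=...) never change.
USE_COLOSSALAI_HYBRID_ADAM = False  # hypothetical switch

if USE_COLOSSALAI_HYBRID_ADAM:
    from colossalai.nn.optimizer import HybridAdam
else:
    from torch.optim import Adam as HybridAdam
```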
@@ -14,37 +19,16 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
 from colossalai.moe import MOE_MANAGER, apply_load_balance
-# from colossalai.nn.optimizer import HybridAdam
-from torch.optim import Adam as HybridAdam
-from colossalai.utils import get_current_device
-import argparse
-import os
-from functools import partial
-from typing import Dict
-
-import torch
-import torch.distributed as dist
-from datasets import load_dataset
-from huggingface_hub import snapshot_download
-from torch.utils.data import Dataset
-from tqdm import tqdm
-from transformers import T5Tokenizer
-from transformers.models.llama import LlamaConfig
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
 
+
 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
+
 def load_ckpt(repo_name: str, model, booster: Booster):
     ckpt_path = snapshot_download(repo_name)
     # single ckpt
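`load_ckpt` resolves a Hugging Face repo id to a local snapshot with `snapshot_download` and then hands the files to the booster's checkpoint IO; the `# single ckpt` comment suggests it distinguishes a single-file checkpoint from a sharded one. A rough, self-contained sketch of that flow, in which the file names and the branch condition are assumptions rather than the function's actual logic:

```python
import os

from colossalai.booster import Booster
from huggingface_hub import snapshot_download


def load_pretrained(repo_name: str, model, booster: Booster) -> None:
    """Illustrative: download a snapshot and load it through the booster."""
    ckpt_path = snapshot_download(repo_name)
    index_file = os.path.join(ckpt_path, "model.safetensors.index.json")
    if os.path.exists(index_file):
        # sharded checkpoint: let the booster's checkpoint IO follow the index
        booster.load_model(model, index_file)
    else:
        # single-file checkpoint
        booster.load_model(model, os.path.join(ckpt_path, "pytorch_model.bin"))
```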
@@ -284,7 +268,9 @@ def main():
     model = model.to(get_current_device())
     replace_moe_layer(model)
     # torch.set_default_tensor_type(torch.float32)
-    print(f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"0-2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
 
     # Enable gradient checkpointing
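A note on the debug prints being reflowed here: `sum(p.numel() ...) / 1000.0 ** 3` is a parameter count in billions, not gigabytes; only the `torch.cuda.memory_allocated()` term is a byte count. A small helper that keeps the two units apart (illustrative, not part of the commit):

```python
import torch


def describe_model(tag: str, model: torch.nn.Module) -> str:
    """Illustrative: parameter count in billions and allocated CUDA memory in GiB."""
    params_b = sum(p.numel() for p in model.parameters()) / 1e9
    mem_gib = torch.cuda.memory_allocated() / (1024**3) if torch.cuda.is_available() else 0.0
    return f"{tag}: params={params_b:.2f}B, cuda_mem={mem_gib:.2f}GiB"
```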
@@ -298,7 +284,9 @@ def main():
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
     )
     torch.cuda.synchronize()
-    print(f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
 
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
@@ -307,10 +295,14 @@ def main():
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     torch.cuda.synchronize()
-    print(f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"2-1 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
     load_ckpt("mistralai/Mixtral-8x7B-v0.1", model, booster)
     torch.cuda.synchronize()
-    print(f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+    print(
+        f"2 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+    )
 
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
@@ -348,11 +340,15 @@ def main():
                     data = move_to_cuda(data, torch.cuda.current_device())
                     outputs = model(**data)
                     loss = outputs["loss"]
-                    print(f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+                    print(
+                        f"3 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+                    )
 
                     # Backward
                     booster.backward(loss, optimizer)
-                    print(f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB")
+                    print(
+                        f"4 param num: {sum(p.numel() for p in model.parameters())/ 1000.0 ** 3}GB, memory: {torch.cuda.memory_allocated()/ 1000.0 ** 3}GB"
+                    )
 
                     pbar.set_postfix({"loss": loss.item()})
 