mirror of https://github.com/hpcaitech/ColossalAI
update optim
parent
f037583bd2
commit
8ca8cf8ec3
|
@ -126,6 +126,12 @@ def _test_moe_checkpoint(parallel):
|
|||
model1, booster1, optim1 = get_model(parallel)
|
||||
model2, booster2, optim2 = get_model(parallel)
|
||||
# param ckpt
|
||||
# check not equal
|
||||
try:
|
||||
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
|
||||
raise AssertionError("state_dict should not be equal")
|
||||
except:
|
||||
pass
|
||||
# shard
|
||||
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
|
||||
booster2.load_model(model2, "./tmp_ckpt1")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import argparse
|
||||
import torch.distributed as dist
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
|
@ -10,7 +10,6 @@ from torch.utils.data import Dataset
|
|||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
|
@ -19,11 +18,11 @@ from colossalai.cluster import DistCoordinator
|
|||
from colossalai.moe import MOE_MANAGER, apply_load_balance
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_global_loss(loss, booster):
|
||||
global_loss = loss.clone().detach()
|
||||
|
@ -31,6 +30,7 @@ def get_global_loss(loss, booster):
|
|||
global_loss.div_(booster.plugin.dp_size)
|
||||
return global_loss
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
|
||||
self.num_samples = num_samples
|
||||
|
@ -97,7 +97,7 @@ def parse_args():
|
|||
# optim
|
||||
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
|
||||
|
||||
# lr scheduler
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
|
@ -197,7 +197,7 @@ def main():
|
|||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
|
||||
dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
|
||||
collate_fn = None
|
||||
dataloader = plugin.prepare_dataloader(
|
||||
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
|
||||
|
@ -211,7 +211,7 @@ def main():
|
|||
weight_decay=args.weight_decay,
|
||||
adamw_mode=True,
|
||||
)
|
||||
|
||||
|
||||
# Set lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer,
|
||||
|
@ -264,7 +264,7 @@ def main():
|
|||
if is_pp_last_stage:
|
||||
loss = outputs["loss"]
|
||||
global_loss = get_global_loss(loss, booster)
|
||||
if coordinator._local_rank == '0':
|
||||
if coordinator._local_rank == "0":
|
||||
pbar.set_postfix({"Loss": global_loss.item()})
|
||||
else:
|
||||
# Forward pass
|
||||
|
|
|
@ -334,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
|
||||
):
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
@ -349,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
|
@ -373,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep extra group
|
||||
if MOE_MANAGER.parallel == "EP":
|
||||
# ep param group
|
||||
if len(optimizer.optim.param_groups) > len(saved_groups):
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1][
|
||||
"params"
|
||||
] # Only keep the parameters kept by current pipeline stage.
|
||||
for param in new_pg["params"]:
|
||||
param.data = param.data.to(torch.float32)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
|
@ -391,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
@ -410,12 +408,14 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
param,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
|
@ -578,6 +578,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
|
@ -620,6 +622,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
prefix (str): Perfix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
|
@ -725,6 +728,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
|
|
|
@ -175,12 +175,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
if len(self.working_moe_params) > 0:
|
||||
self._sync_master_param = False
|
||||
param_group = dict()
|
||||
# create fp32 master param
|
||||
for key, value in self.optim.param_groups[0].items():
|
||||
if key != "params":
|
||||
param_group[key] = value
|
||||
self.master_moe_params = []
|
||||
for param in self.working_moe_params:
|
||||
self.master_moe_params.append(param.clone().to(torch.float32).detach())
|
||||
# create mapping from master to working for optimizer io
|
||||
self.moe_master_to_working_map = {}
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
|
||||
# add to optim
|
||||
param_group["params"] = self.master_moe_params
|
||||
self.optim.param_groups.append(param_group)
|
||||
|
||||
|
|
Loading…
Reference in New Issue