From 8ca8cf8ec3b1bc729afcc3cd63677871ec20fd35 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Wed, 3 Jan 2024 11:57:23 +0800
Subject: [PATCH] update optim

---
 .../ColossalMoE/tests/test_moe_checkpoint.py  |  6 +++++
 applications/ColossalMoE/train.py             | 14 +++++-----
 colossalai/moe/checkpoint.py                  | 26 +++++++++++--------
 colossalai/zero/low_level/low_level_optim.py  |  6 +++++
 4 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
index 772cbb977..7c6012a70 100644
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -126,6 +126,12 @@ def _test_moe_checkpoint(parallel):
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     # param ckpt
+    # check not equal
+    try:
+        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
+        raise AssertionError("state_dict should not be equal")
+    except:
+        pass
     # shard
     booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
     booster2.load_model(model2, "./tmp_ckpt1")
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index a4cf438aa..7c8807c24 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -1,7 +1,7 @@
 import argparse
 
-import torch.distributed as dist
 import torch
+import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
@@ -10,7 +10,6 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
@@ -19,11 +18,11 @@ from colossalai.cluster import DistCoordinator
 from colossalai.moe import MOE_MANAGER, apply_load_balance
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 
-
 @torch.no_grad()
 def get_global_loss(loss, booster):
     global_loss = loss.clone().detach()
@@ -31,6 +30,7 @@ def get_global_loss(loss, booster):
     global_loss.div_(booster.plugin.dp_size)
     return global_loss
 
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
         self.num_samples = num_samples
@@ -97,7 +97,7 @@ def parse_args():
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
-
+
     # lr scheduler
     parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
@@ -197,7 +197,7 @@ def main():
 
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
+    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
     collate_fn = None
     dataloader = plugin.prepare_dataloader(
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
@@ -211,7 +211,7 @@
         weight_decay=args.weight_decay,
         adamw_mode=True,
     )
-
+
     # Set lr scheduler
     lr_scheduler = CosineAnnealingWarmupLR(
         optimizer=optimizer,
@@ -264,7 +264,7 @@
                     if is_pp_last_stage:
                         loss = outputs["loss"]
                         global_loss = get_global_loss(loss, booster)
-                        if coordinator._local_rank == '0':
+                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index 9e0d53aeb..9928c801d 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -334,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
-            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
         ):
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             return optimizer.param_info["param2id"][id(working_param)]
@@ -349,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -373,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             new_pg = copy.deepcopy(saved_pg)
             new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
             updated_groups.append(new_pg)
-        # ep extra group
-        if MOE_MANAGER.parallel == "EP":
+        # ep param group
+        if len(optimizer.optim.param_groups) > len(saved_groups):
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = optimizer.optim.param_groups[-1][
-                "params"
-            ]  # Only keep the parameters kept by current pipeline stage.
-            for param in new_pg["params"]:
-                param.data = param.data.to(torch.float32)
+            new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})
 
@@ -391,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -410,12 +408,14 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             device = param.device
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
             sharded_state = self.pre_load_optim(
                 state,
-                param,
+                working_param,
                 current_shape=working_param.shape,
                 original_shape=original_shape,
                 device=device,
@@ -578,6 +578,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
 
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
 
@@ -620,6 +622,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        torch.cuda.empty_cache()
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -725,6 +728,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 f"You can find where each parameters has been saved in the "
                 f"index located at {final_index_file_path}."
             )
+        torch.cuda.empty_cache()
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8d2346a3c..553383f4c 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -175,12 +175,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if len(self.working_moe_params) > 0:
             self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
                 self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)
 
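
Note on the optimizer-IO change in this patch: LowLevelZeroOptimizer now keeps its fp32 MoE master parameters in an extra param group and records a moe_master_to_working_map keyed by id() of each master copy, and MoECheckpintIO falls back to that map whenever the regular get_master_to_working_map() lookup misses. The sketch below only illustrates that lookup order; resolve_working_param is a hypothetical helper name used for illustration and is not part of the patch.

    import torch

    def resolve_working_param(param: torch.Tensor, optimizer) -> torch.Tensor:
        # Regular ZeRO mapping first: fp32 master param -> working param.
        master_to_working_map = optimizer.get_master_to_working_map()
        if master_to_working_map is not None and id(param) in master_to_working_map:
            return master_to_working_map[id(param)]
        # Added by this patch: MoE master params live in their own param group,
        # so they carry their own id()-keyed mapping on the optimizer.
        if hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
            return optimizer.moe_master_to_working_map[id(param)]
        # No mapping found: the parameter is already a working parameter.
        return param

The mapping itself is built in LowLevelZeroOptimizer.__init__ by zipping master_moe_params with working_moe_params, which is what lets the checkpoint IO recover the working parameter's id and shape from optimizer.param_info when saving or loading optimizer states.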
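
Note on the new "check not equal" block in test_moe_checkpoint.py: the bare except also swallows the AssertionError("state_dict should not be equal") raised inside the try block, so the check cannot actually fail if the two state dicts happen to match. A tighter sketch, assuming check_state_dict_equal raises AssertionError when the state dicts differ:

    try:
        # The two freshly built models are expected to differ before any checkpoint is loaded.
        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
    except AssertionError:
        pass  # expected path: the state dicts do not match
    else:
        raise RuntimeError("state_dict should not be equal")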