diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 46f38390a..378bbd2fc 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -9,7 +9,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor, inf
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
@@ -21,7 +20,6 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -76,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
-        moe_extra_dp_process_group: Optional[ProcessGroup] = None,
         master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -102,16 +99,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
 
-        # extra dp
-        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
-        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
-        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
-        # And moe working and master param are split by extra dp pg.
-        self.moe_extra_dp_pg = moe_extra_dp_process_group
-        if self.moe_extra_dp_pg is not None:
-            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
-            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
-
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
@@ -143,12 +130,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
         self._bucket_store = BucketStore(self.dp_pg)
 
-        # moe param should not be stored in working_groups
-        # because they have different parallel strategy
-        # so we need to store them separately in param_groups
-        # instead of working_groups
-        self.working_moe_params = list()
-
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
@@ -156,11 +137,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             group_params = list()
             for param in param_group["params"]:
                 if param.requires_grad:
-                    if self.moe_extra_dp_pg is None:
-                        # skip moe param
-                        if is_moe_tensor(param):
-                            self.working_moe_params.append(param)
-                            continue
                     group_params.append(param)
 
             # add the working params to working_param_groups for bookkeeping
@@ -174,25 +150,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank
 
-        # if there are moe params, store in addtional group in optim
-        if len(self.working_moe_params) > 0:
-            self._sync_master_param = False
-            param_group = dict()
-            # create fp32 master param
-            for key, value in self.optim.param_groups[0].items():
-                if key != "params":
-                    param_group[key] = value
-            self.master_moe_params = []
-            for param in self.working_moe_params:
-                self.master_moe_params.append(param.clone().to(torch.float32).detach())
-            # create mapping from master to working for optimizer io
-            self.moe_master_to_working_map = {}
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
-            # add to optim
-            param_group["params"] = self.master_moe_params
-            self.optim.param_groups.append(param_group)
-
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
@@ -256,12 +213,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 else:
                     padding_param = param.data.view(-1)
 
-                if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
-                    splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
-                    splited_params = splited_params[self.moe_extra_dp_pg_rank]
-                else:
-                    splited_params = padding_param.split(padding_param.numel() // self._world_size)
-                    splited_params = splited_params[self._local_rank]
+                splited_params = padding_param.split(padding_param.numel() // self._world_size)
+                splited_params = splited_params[self._local_rank]
 
                 # use fp32 when master_weights is True
                 if self._master_weights is True:
@@ -301,43 +254,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self._bucket_store.num_elements_in_bucket() > 0:
             self._bucket_store.build_grad_in_bucket()
 
-            if self.moe_extra_dp_pg is None:
-                flat_grads = self._bucket_store.get_flatten_grad()
-                flat_grads /= self._world_size
-            else:
-                # record moe and non moe param
-                moe_list = []
-                for param in self._bucket_store._param_list:
-                    moe_list.append(is_moe_tensor(param))
-
-                # divide them into different groups
-                moe_grad_list = []
-                non_moe_grad_list = []
-                for grad_list in self._bucket_store._grad_in_bucket.values():
-                    non_moe_cur_grad = []
-                    moe_cur_grad = []
-                    for i in range(len(grad_list)):
-                        if moe_list[i] == True:
-                            moe_cur_grad.append(grad_list[i])
-                        else:
-                            non_moe_cur_grad.append(grad_list[i])
-                    if len(moe_cur_grad) > 0:
-                        moe_grad_list.append(moe_cur_grad)
-                    if len(non_moe_cur_grad) > 0:
-                        non_moe_grad_list.append(non_moe_cur_grad)
-
-                if len(non_moe_grad_list) > 0:
-                    non_moe_flat_grads = []
-                    for grad_list in non_moe_grad_list:
-                        non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
-                    non_moe_flat_grads /= self._world_size
-
-                if len(moe_grad_list) > 0:
-                    moe_flat_grads = []
-                    for grad_list in moe_grad_list:
-                        moe_flat_grads.append(_flatten_dense_tensors(grad_list))
-                    moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
+            flat_grads = self._bucket_store.get_flatten_grad()
+            flat_grads /= self._world_size
 
             # ready to add other tensors to bucket
            self._bucket_store.reset_num_elements_in_bucket()
@@ -345,13 +263,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if self._overlap_communication:
                 stream = self._comm_stream
                 # in case of the memory being reused in the default stream
-                if self.moe_extra_dp_pg is None:
-                    flat_grads.record_stream(stream)
-                else:
-                    if len(non_moe_grad_list) > 0:
-                        non_moe_flat_grads.record_stream(stream)
-                    if len(moe_grad_list) > 0:
-                        moe_flat_grads.record_stream(stream)
+                flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
                 stream.wait_stream(get_accelerator().current_stream())
             else:
@@ -360,84 +272,29 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             with get_accelerator().stream(stream):
                 group_id = self._bucket_store.current_group_id
 
-                if self.moe_extra_dp_pg is None:
-                    grad_dtype = flat_grads.dtype
-                    if self._communication_dtype is not None:
-                        flat_grads = flat_grads.to(self._communication_dtype)
+                grad_dtype = flat_grads.dtype
+                if self._communication_dtype is not None:
+                    flat_grads = flat_grads.to(self._communication_dtype)
 
                 if not self._partition_grads:
-                    if self.moe_extra_dp_pg is None:
-                        dist.all_reduce(flat_grads, group=self.dp_pg)
-                        if flat_grads.dtype != grad_dtype:
-                            flat_grads = flat_grads.to(grad_dtype)
+                    dist.all_reduce(flat_grads, group=self.dp_pg)
+                    if flat_grads.dtype != grad_dtype:
+                        flat_grads = flat_grads.to(grad_dtype)
 
-                        flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
-                        grad_in_bucket = self._bucket_store.get_grad()
-                        self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
-
-                    # sync extra zero group
-                    else:
-                        # sync non moe param in global dp group
-                        if len(non_moe_grad_list) > 0:
-                            dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
-                            flat_grads_per_rank = non_moe_flat_grads.split(
-                                non_moe_flat_grads.numel() // self._world_size
-                            )
-                            self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
-
-                        # sync moe param only in zero group
-                        if len(moe_grad_list) > 0:
-                            dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
-                            flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
-                            self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
+                    flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+                    grad_in_bucket = self._bucket_store.get_grad()
+                    self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
 
                 else:
-                    if self.moe_extra_dp_pg is None:
-                        flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
-                        recieved_grad = torch.zeros_like(flat_grads_list[0])
-                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+                    flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+                    recieved_grad = torch.zeros_like(flat_grads_list[0])
+                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
 
-                        if recieved_grad.dtype != grad_dtype:
-                            recieved_grad = recieved_grad.to(grad_dtype)
+                    if recieved_grad.dtype != grad_dtype:
+                        recieved_grad = recieved_grad.to(grad_dtype)
 
-                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                        self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
-                    else:
-                        # categorize moe and non moe param
-                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                        moe_grad_in_bucket_current_rank = []
-                        non_moe_grad_in_bucket_current_rank = []
-                        for idx, grad in enumerate(grad_in_bucket_current_rank):
-                            if moe_list[idx] == True:
-                                moe_grad_in_bucket_current_rank.append(grad)
-                            else:
-                                non_moe_grad_in_bucket_current_rank.append(grad)
-
-                        if len(non_moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
-                            )
-                            recieved_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
-                            self._update_partitoned_grad(
-                                non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
-                            )
-
-                        if len(moe_grad_list) > 0:
-                            flat_grads_list = list(
-                                moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
-                            )
-                            recieved_grad = torch.zeros_like(flat_grads_list[0])
-                            dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
-                            param_slice = self._world_size // self.moe_extra_dp_pg_size
-                            recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
-                            for split_recieved_grad in recieved_grad:
-                                split_recieved_grad = _unflatten_dense_tensors(
-                                    split_recieved_grad, moe_grad_in_bucket_current_rank
-                                )
-                                for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
-                                    param_id = self._bucket_store.get_param_id_of_grad(grad)
-                                    self._add_grad(real_grad, param_slice, group_id, param_id)
+                    grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                    self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
 
             self._bucket_store.reset()
 
@@ -578,20 +435,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 # else the splited grad should be attached to the splited param
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
-                    # moe hybrid zero
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                        real_working_params[group_id].append(working_param)
-                        if self._partition_grads:
-                            grad = grads
-                        else:
-                            param_slice = self._world_size // self.moe_extra_dp_pg_size
-                            grad = grads[
-                                self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
-                            ]
-                        grad = flatten(grad)
-                    else:
-                        real_working_params[group_id].append(working_param)
-                        grad = grads[grad_index]
+                    real_working_params[group_id].append(working_param)
+                    grad = grads[grad_index]
                     # no need to copy fp32 grad if master_weights is False
                     if self._master_weights:
                         grad = grad.to(splited_param.dtype).to(splited_param.device)
@@ -609,26 +454,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
 
-        # update param for moe ep
-        # move grad to master param and compute norm
-        if len(self.working_moe_params) > 0:
-            moe_grads = []
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                if master_moe_param.grad is not None:
-                    raise RuntimeError("Moe param should not have grad here")
-                grad = working_moe_param.grad
-                # no need to copy fp32 grad if master_weights is False
-                if self._master_weights:
-                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
-                master_moe_param.grad = grad
-                working_moe_param.grad = None
-                moe_grads.append(grad)
-                grad_partition_groups.append(grad)
-            norm_group = self._compute_grad_norm(gradients=moe_grads)
-            norm_groups.append(norm_group)
-            self.optim.param_groups[-1]["params"] = self.master_moe_params
-            del moe_grads
-
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
@@ -636,14 +461,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # update the parameters
         self.optim.step()
 
-        # release moe grad
-        if len(self.working_moe_params) > 0:
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.grad = None
-                working_moe_param.data = (
-                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
-                )
-
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -655,20 +472,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self.moe_extra_dp_pg_size)
-                    ]
-                    dist.all_gather(
-                        all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
-                    )
-                else:
-                    all_splited_param = [
-                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
-                        for _ in range(self._world_size)
-                    ]
-                    dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
@@ -802,16 +609,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self._param_store.master_to_working_param[id(param)]
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
-                    else:
-                        gather_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
-                        ]
-                        dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
+                    gather_tensor = [
+                        torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
                     param_state = (
                         torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                     )
@@ -836,12 +637,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         v = v.flatten()
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
-                        if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                            v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
-                            zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
-                        else:
-                            v_list = v.split(v.numel() // self._world_size)
-                            zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
+                        v_list = v.split(v.numel() // self._world_size)
+                        zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
 
         self.optim.load_state_dict(zero_state_dict)
 
@@ -873,16 +670,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
-                    else:
-                        state_tensor = [
-                            torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
-                        ]
-                        dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
+                    state_tensor = [torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
                     state_tensor = (
                         torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
                     )
@@ -913,18 +702,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = p.data.view(-1)
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
-                    master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
-                else:
-                    master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-        if hasattr(self, "master_moe_params"):
-            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-                master_moe_param.copy_(working_moe_param)
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
 
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        if hasattr(self, "moe_master_to_working_map"):
-            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
         return self._param_store.master_to_working_param