diff --git a/configs/20B_sft.py b/configs/20B_sft.py
index 5a9021b..13e68b2 100644
--- a/configs/20B_sft.py
+++ b/configs/20B_sft.py
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
diff --git a/configs/30B_sft.py b/configs/30B_sft.py
index ec04048..8bde057 100644
--- a/configs/30B_sft.py
+++ b/configs/30B_sft.py
@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 6144
 NUM_ATTENTION_HEAD = 48
 MLP_RATIO = 8 / 3
-NUM_LAYER = 40
+NUM_LAYER = 60
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -161,8 +161,8 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="origin_tp", overlap=False),
+    zero1=dict(size=4, fsdp=False),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 106548a..6ea8b96 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp"),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 5b4018c..2667efe 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device
 
 logger = get_logger(__file__)
 
@@ -148,6 +149,18 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
                                                      async_op=async_op)
     return output, handle
 
+def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    size = (input_.shape[0] // world_size, *input_.shape[1:])
+    index = check_reduce_scatter_memory_pool(size)
+    output = gpc.config.reduce_scatter_memory[size]['data'][index]
+    setattr(output, "index", index)
+    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
+                                                     group=process_group,
+                                                     async_op=async_op)
+    return output, handle
+
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
@@ -404,12 +417,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 # assert hasattr(bias, "_fstp_all_reduce_str")
                 # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
                 # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-            grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+
+            grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
             assert hasattr(weight, "_fstp_reduce_scatter_str")
             all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
             grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
             if grad_bias is not None:
-                grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                 assert hasattr(bias, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
                 grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
@@ -521,3 +535,37 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
+
+def check_reduce_scatter_memory_pool(key):
+
+    return_idx = 0
+
+    # if key not in dict
+    if key not in gpc.config.reduce_scatter_memory:
+        gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+    # if the data is empty
+    if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = 0
+        return return_idx
+    else:  # if not empty
+        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
+            if used == False:
+                gpc.config.reduce_scatter_memory[key]['used'][index] = True
+                return_idx = index
+                return return_idx
+        # if the memory pool is all used
+        length = len(gpc.config.reduce_scatter_memory[key]['data'])
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = length
+        return return_idx
+
+def release_reduce_scatter_memory_pool(size, index):
+    gpc.config.reduce_scatter_memory[size]['used'][index] = False
\ No newline at end of file
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d0cdd10..96a54c0 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward
+from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -353,7 +353,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle.wait()
                     _param.grad.add_(_grad)
                     # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    del _grad
+                    # del _grad
+                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
@@ -395,7 +396,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle.wait()
                     _param.grad.add_(_grad)
                     # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    del _grad
+                    # del _grad
+                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index f39e384..2816da0 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -51,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -123,7 +123,8 @@ def initialize_model():
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-        size_key = [(3 * hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (hidden_size, hidden_size)]
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
         module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
         for i in range(2):
             weight = {}
@@ -131,21 +132,26 @@ def initialize_model():
                 if name == 'Wqkv':
                     weight[name] = torch.zeros((3 * hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 elif name == 'out_proj':
                     weight[name] = torch.zeros((hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 elif name == 'w1' or name == 'w2':
                     weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 else:
                     weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
             block_memory[i] = weight
+        reduce_scatter_memory = {}
+        for key in size_key:
+            reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+        gpc.config.block_memory = block_memory
+        gpc.config.reduce_scatter_memory = reduce_scatter_memory
 
     return model
diff --git a/train.py b/train.py
index c972bea..41ab070 100644
--- a/train.py
+++ b/train.py
@@ -299,6 +299,7 @@ def main(args):
 
         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
+            gpc.config.fstp_handler.reduce_scatter_memory = {}
 
         torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
         torch.cuda.reset_peak_memory_stats()
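
The three new helpers in internlm/model/utils.py amount to a simple per-shape buffer pool for reduce-scatter outputs. Below is a minimal standalone sketch of that pattern; the class and method names (ReduceScatterPool, acquire, release) are invented for illustration, and the real code keeps the pool in gpc.config.reduce_scatter_memory rather than on an object.

    # Standalone sketch (illustrative names) of the buffer-pool pattern used for
    # reduce-scatter outputs: buffers are cached per shape and recycled via a
    # 'used' flag instead of being freed with `del` after every backward pass.
    import torch


    class ReduceScatterPool:
        def __init__(self, dtype=torch.half, device=None):
            self.dtype = dtype
            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
            # shape -> {"data": [tensor, ...], "used": [bool, ...]}
            self.pool = {}

        def acquire(self, shape):
            """Return (buffer, index) for a free buffer of `shape`, growing the pool if all are busy."""
            entry = self.pool.setdefault(tuple(shape), {"data": [], "used": []})
            for idx, used in enumerate(entry["used"]):
                if not used:  # reuse a buffer released after a previous step
                    entry["used"][idx] = True
                    return entry["data"][idx], idx
            # every cached buffer is still in flight: allocate one more
            entry["data"].append(torch.zeros(shape, dtype=self.dtype, device=self.device))
            entry["used"].append(True)
            return entry["data"][-1], len(entry["data"]) - 1

        def release(self, shape, index):
            """Mark the buffer free so the next reduce-scatter of this shape reuses it."""
            self.pool[tuple(shape)]["used"][index] = False

In the patch, check_reduce_scatter_memory_pool plays the role of acquire (stashing the chosen index on the output tensor via setattr), and the optimizer calls release_reduce_scatter_memory_pool in place of `del _grad` once the gradient has been accumulated, so buffers cycle through the pool rather than being repeatedly allocated and freed.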