support reduce scatter memory pool

pull/456/head
yingtongxiong 2023-10-20 10:35:45 +08:00
parent 4742271154
commit ed7232777a
7 changed files with 74 additions and 17 deletions

View File

@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded

View File

@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 6144
 NUM_ATTENTION_HEAD = 48
 MLP_RATIO = 8 / 3
-NUM_LAYER = 40
+NUM_LAYER = 60
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -161,8 +161,8 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="origin_tp", overlap=False),
+    zero1=dict(size=4, fsdp=False),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )

View File

@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp"),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )

View File

@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device
 
 
 logger = get_logger(__file__)
@@ -148,6 +149,18 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
                                                      async_op=async_op)
     return output, handle
 
+
+def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    size = (input_.shape[0] // world_size, *input_.shape[1:])
+    index = check_reduce_scatter_memory_pool(size)
+    output = gpc.config.reduce_scatter_memory[size]['data'][index]
+    setattr(output, "index", index)
+    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
+                                                     group=process_group,
+                                                     async_op=async_op)
+    return output, handle
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
@@ -404,12 +417,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             # assert hasattr(bias, "_fstp_all_reduce_str")
             # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
             # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-            grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+            grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
             assert hasattr(weight, "_fstp_reduce_scatter_str")
             all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
             grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
             if grad_bias is not None:
-                grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                 assert hasattr(bias, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
                 grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
@@ -521,3 +535,37 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
+
+
+def check_reduce_scatter_memory_pool(key):
+    return_idx = 0
+    # if key not in dict
+    if key not in gpc.config.reduce_scatter_memory:
+        gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+    # if the data is empty
+    if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = 0
+        return return_idx
+    else:  # if not empty
+        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
+            if used == False:
+                gpc.config.reduce_scatter_memory[key]['used'][index] = True
+                return_idx = index
+                return return_idx
+        # if the memory pool is all used
+        length = len(gpc.config.reduce_scatter_memory[key]['data'])
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = length
+        return return_idx
+
+
+def release_reduce_scatter_memory_pool(size, index):
+    gpc.config.reduce_scatter_memory[size]['used'][index] = False
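The two helpers above implement a grow-on-demand buffer pool keyed by output shape: check_reduce_scatter_memory_pool hands back the index of a free slot for a given shape (allocating a new zero buffer only when every slot is busy), and release_reduce_scatter_memory_pool marks a slot reusable. Below is a minimal standalone sketch of that bookkeeping, assuming plain dicts and CPU tensors instead of gpc.config and get_current_device(); the names acquire/release are illustrative, not the repo's API.

import torch

pool = {}  # toy stand-in for gpc.config.reduce_scatter_memory

def acquire(shape):
    # Return (buffer, slot_index) for `shape`, reusing a freed slot when possible.
    entry = pool.setdefault(shape, {'data': [], 'used': []})
    for idx, used in enumerate(entry['used']):
        if not used:                              # reuse a previously released slot
            entry['used'][idx] = True
            return entry['data'][idx], idx
    entry['data'].append(torch.zeros(shape, dtype=torch.float16))  # empty or all busy: grow by one
    entry['used'].append(True)
    return entry['data'][-1], len(entry['data']) - 1

def release(shape, index):
    # Mark the slot free; the tensor itself stays allocated for reuse.
    pool[shape]['used'][index] = False

buf0, i0 = acquire((4, 8))    # allocates slot 0
buf1, i1 = acquire((4, 8))    # slot 0 busy, allocates slot 1
release((4, 8), i0)
buf2, i2 = acquire((4, 8))    # reuses slot 0 rather than allocating again
assert i2 == i0 and buf2 is buf0

The point of the pool is that the reduce-scatter outputs produced on every backward step keep cycling through the same few buffers instead of being freshly allocated and freed each time.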

View File

@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward
+from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -353,7 +353,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle.wait()
                 _param.grad.add_(_grad)
                 # self._fstp_handler.reduce_scatter_handlers[key] = None
-                del _grad
+                # del _grad
+                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                 del self._fstp_handler.reduce_scatter_handlers[key]
                 self._fstp_handler.reduce_scatter_handlers[key] = None
                 assert key in self._fstp_handler.reduce_scatter_handlers
@@ -395,7 +396,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle.wait()
                 _param.grad.add_(_grad)
                 # self._fstp_handler.reduce_scatter_handlers[key] = None
-                del _grad
+                # del _grad
+                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                 del self._fstp_handler.reduce_scatter_handlers[key]
                 self._fstp_handler.reduce_scatter_handlers[key] = None
                 assert key in self._fstp_handler.reduce_scatter_handlers
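Both release sites above recover the pool slot from the gradient tensor itself: reduce_scatter_raw_memory_pool stamped the pooled output with its slot index via setattr, and tuple(_grad.size()) reproduces the shape key the pool is indexed by, so the buffer can be returned without any extra bookkeeping next to the handler. A tiny illustration of that handshake, using a toy CPU tensor and no communication (the release call appears only in a comment):

import torch

grad = torch.zeros(4, 8, dtype=torch.float16)   # stand-in for a pooled reduce-scatter output
grad.index = 3                                  # producer side: setattr(output, "index", index)

# consumer side, after comm_handle.wait() and _param.grad.add_(_grad):
key, slot = tuple(grad.size()), grad.index      # ((4, 8), 3)
# release_reduce_scatter_memory_pool(size=key, index=slot)  # flips used[slot] back to False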

View File

@@ -51,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -123,7 +123,8 @@ def initialize_model():
     mlp_ratio = gpc.config.MLP_RATIO
     mlp_hidden_size = int(hidden_size * mlp_ratio)
     mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-    size_key = [(3 * hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (hidden_size, hidden_size)]
+    world_size = gpc.get_world_size(ParallelMode.TENSOR)
+    size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
     module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
     for i in range(2):
         weight = {}
@@ -131,21 +132,26 @@ def initialize_model():
             if name == 'Wqkv':
                 weight[name] = torch.zeros((3 * hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             elif name == 'out_proj':
                 weight[name] = torch.zeros((hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             elif name == 'w1' or name == 'w2':
                 weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             else:
                 weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
         block_memory[i] = weight
 
+    reduce_scatter_memory = {}
+    for key in size_key:
+        reduce_scatter_memory[key] = {'data': [], 'used': []}
+
     gpc.config.block_memory = block_memory
+    gpc.config.reduce_scatter_memory = reduce_scatter_memory
 
     return model
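For concreteness, here is the arithmetic behind the pre-registered pool keys under the config touched by this commit (HIDDEN_SIZE=6144, MLP_RATIO=8/3, tensor parallel size 8). The values are hard-coded for illustration rather than read from gpc, and the per-module attribution in the comments is inferred from the block_memory shapes above.

hidden_size = 6144
mlp_hidden_size = 256 * ((int(hidden_size * 8 / 3) + 256 - 1) // 256)   # = 16384
world_size = 8
size_key = [
    (3 * hidden_size // world_size, hidden_size),   # (2304, 6144)  -- Wqkv gradient shard
    (mlp_hidden_size // world_size, hidden_size),   # (2048, 6144)  -- w1 / w2 gradient shards
    (hidden_size // world_size, mlp_hidden_size),   # (768, 16384)  -- w3 gradient shard
    (hidden_size // world_size, hidden_size),       # (768, 6144)   -- out_proj gradient shard
]
print(size_key)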

View File

@@ -299,6 +299,7 @@ def main(args):
 
         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
+            gpc.config.fstp_handler.reduce_scatter_memory = {}
 
             torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()