mirror of https://github.com/InternLM/InternLM

support reduce scatter memory pool

parent 4742271154
commit ed7232777a
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 6144
 NUM_ATTENTION_HEAD = 48
 MLP_RATIO = 8 / 3
-NUM_LAYER = 40
+NUM_LAYER = 60
 VOCAB_SIZE = 103168

 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -161,8 +161,8 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="origin_tp", overlap=False),
+    zero1=dict(size=4, fsdp=False),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp"),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device

 logger = get_logger(__file__)

@@ -148,6 +149,18 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
                                                      async_op=async_op)
     return output, handle

+def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    size = (input_.shape[0] // world_size, *input_.shape[1:])
+    index = check_reduce_scatter_memory_pool(size)
+    output = gpc.config.reduce_scatter_memory[size]['data'][index]
+    setattr(output, "index", index)
+    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
+                                                     group=process_group,
+                                                     async_op=async_op)
+    return output, handle
+

 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
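Note: reduce_scatter_raw allocates a fresh output tensor on every backward call, while the new reduce_scatter_raw_memory_pool borrows a cached buffer keyed by the output shape and stamps it with its pool slot index so the consumer can return it later. A minimal, self-contained sketch of that pooling idea follows; the module-level _POOL dict and the function names are illustrative, not the repository's gpc-backed implementation.

# Minimal, self-contained sketch of the pooling idea (illustrative names; the
# repository keys its pool off gpc.config.reduce_scatter_memory instead).
from typing import Dict, Tuple

import torch
import torch.distributed as dist

# shape -> {"data": [cached output tensors], "used": [per-buffer in-use flags]}
_POOL: Dict[Tuple[int, ...], Dict[str, list]] = {}


def _acquire(shape, dtype, device) -> int:
    """Return the index of a free cached buffer for `shape`, allocating one if needed."""
    slot = _POOL.setdefault(shape, {"data": [], "used": []})
    for idx, used in enumerate(slot["used"]):
        if not used:
            slot["used"][idx] = True
            return idx
    slot["data"].append(torch.zeros(shape, dtype=dtype, device=device))
    slot["used"].append(True)
    return len(slot["data"]) - 1


def reduce_scatter_pooled(input_: torch.Tensor, group, async_op: bool = False):
    """Reduce-scatter `input_` into a pooled buffer instead of a fresh allocation."""
    world_size = dist.get_world_size(group)
    assert input_.shape[0] % world_size == 0
    out_shape = (input_.shape[0] // world_size, *input_.shape[1:])
    idx = _acquire(out_shape, input_.dtype, input_.device)
    output = _POOL[out_shape]["data"][idx]
    setattr(output, "index", idx)  # remember the slot so the consumer can release it later
    handle = dist.reduce_scatter_tensor(output, input_.contiguous(), group=group, async_op=async_op)
    return output, handle


def release(shape: Tuple[int, ...], index: int) -> None:
    """Mark a pooled buffer as reusable once its contents have been consumed."""
    _POOL[shape]["used"][index] = False

A buffer stays marked as used until release is called, so overlapping in-flight reduce-scatters of the same shape each get their own slot and the pool only grows to the peak number of outstanding gradients per shape.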
@@ -404,12 +417,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
            # assert hasattr(bias, "_fstp_all_reduce_str")
            # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
            # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-           grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+
+           grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
            assert hasattr(weight, "_fstp_reduce_scatter_str")
            all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
            grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
            if grad_bias is not None:
-               grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+               grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                assert hasattr(bias, "_fstp_reduce_scatter_str")
                all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
                grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
@@ -521,3 +535,37 @@ def Silu(w1_o, w2_o):


 Silu = torch.jit.script(Silu)
+
+def check_reduce_scatter_memory_pool(key):
+
+    return_idx = 0
+
+    # if key not in dict
+    if key not in gpc.config.reduce_scatter_memory:
+        gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+    # if the data is empty
+    if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = 0
+        return return_idx
+    else:  # if not empty
+        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
+            if used == False:
+                gpc.config.reduce_scatter_memory[key]['used'][index] = True
+                return_idx = index
+                return return_idx
+        # if the memory pool is all used
+        length = len(gpc.config.reduce_scatter_memory[key]['data'])
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = length
+        return return_idx
+
+def release_reduce_scatter_memory_pool(size, index):
+    gpc.config.reduce_scatter_memory[size]['used'][index] = False
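Note: check_reduce_scatter_memory_pool hands out a free slot (allocating a new zero buffer on demand) and marks it used; release_reduce_scatter_memory_pool flips the flag back so the buffer can be recycled. A sketch of the borrow/release lifecycle from a caller's point of view, assuming torch.distributed is initialized, initialize_model() has populated gpc.config.reduce_scatter_memory, and reduce_scatter_raw_memory_pool lives in internlm.model.utils alongside release_reduce_scatter_memory_pool:

# Sketch only: shows how the pooled reduce-scatter output is consumed and then
# returned to the pool via the index attribute stamped on it.
import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import (
    reduce_scatter_raw_memory_pool,
    release_reduce_scatter_memory_pool,
)


def reduce_scatter_and_release(grad: torch.Tensor) -> torch.Tensor:
    group = gpc.get_group(ParallelMode.TENSOR)
    # Borrow a pooled buffer and launch the asynchronous reduce-scatter into it.
    output, handle = reduce_scatter_raw_memory_pool(grad, group, async_op=True)
    handle.wait()            # communication done; output now holds this rank's shard
    result = output.clone()  # copy out before handing the buffer back to the pool
    release_reduce_scatter_memory_pool(size=tuple(output.size()), index=output.index)
    return result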
@@ -10,7 +10,7 @@ from torch.optim import Optimizer

 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward
+from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -353,7 +353,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle.wait()
                     _param.grad.add_(_grad)
                     # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    del _grad
+                    # del _grad
+                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
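Note: in both optimizer hunks the reduced gradient is no longer deleted; once the async handle completes and the shard has been accumulated into the parameter gradient, the pooled buffer is handed back via release_reduce_scatter_memory_pool, using the index attribute stamped on it by reduce_scatter_raw_memory_pool. A condensed sketch of that pattern, with the handlers/params bookkeeping simplified and illustrative:

# Condensed, illustrative sketch of the wait -> accumulate -> release pattern;
# `handlers` maps a key to (comm_handle, pooled_grad) and `params` maps the same
# key to the parameter being accumulated (the real optimizer bookkeeping differs).
from internlm.model.utils import release_reduce_scatter_memory_pool


def accumulate_pooled_grads(handlers, params) -> None:
    for key, (comm_handle, _grad) in handlers.items():
        comm_handle.wait()              # make sure the async reduce-scatter finished
        params[key].grad.add_(_grad)    # accumulate this rank's gradient shard
        # Return the pooled buffer instead of freeing it, so later steps reuse it.
        release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)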
@@ -395,7 +396,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle.wait()
                     _param.grad.add_(_grad)
                     # self._fstp_handler.reduce_scatter_handlers[key] = None
-                    del _grad
+                    # del _grad
+                    release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
                     del self._fstp_handler.reduce_scatter_handlers[key]
                     self._fstp_handler.reduce_scatter_handlers[key] = None
                     assert key in self._fstp_handler.reduce_scatter_handlers
@@ -51,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -123,7 +123,8 @@ def initialize_model():
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-        size_key = [(3 * hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (hidden_size, hidden_size)]
+        world_size = gpc.get_world_size(ParallelMode.TENSOR)
+        size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
         module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
         for i in range(2):
             weight = {}
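Note: the first dimension of each size_key entry is now divided by the tensor-parallel world size because the pooled buffers hold only this rank's reduce-scatter shard of a weight gradient. Illustrative arithmetic with the values used in this config (HIDDEN_SIZE = 6144, MLP_RATIO = 8 / 3, tensor size 8); the module attributions in the comments follow the full-weight shapes used for block_memory below and are an inference:

# Illustrative arithmetic only; values taken from this config.
hidden_size = 6144
world_size = 8                                                 # tensor-parallel size
mlp_hidden_size = int(hidden_size * 8 / 3)                     # 16384
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)   # round up to a multiple of 256 -> 16384

size_key = [
    (3 * hidden_size // world_size, hidden_size),   # (2304, 6144), Wqkv gradient shard
    (mlp_hidden_size // world_size, hidden_size),   # (2048, 6144), w1 / w2 gradient shard
    (hidden_size // world_size, mlp_hidden_size),   # (768, 16384), w3 gradient shard
    (hidden_size // world_size, hidden_size),       # (768, 6144),  out_proj gradient shard
]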
@@ -131,21 +132,26 @@ def initialize_model():
                 if name == 'Wqkv':
                     weight[name] = torch.zeros((3 * hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 elif name == 'out_proj':
                     weight[name] = torch.zeros((hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 elif name == 'w1' or name == 'w2':
                     weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
                 else:
                     weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
                                                dtype=gpc.config.model.get("dtype", torch.half),
-                                               device='cuda').contiguous()
+                                               device=get_current_device()).contiguous()
             block_memory[i] = weight
+        reduce_scatter_memory = {}
+        for key in size_key:
+            reduce_scatter_memory[key] = {'data': [], 'used': []}
+
         gpc.config.block_memory = block_memory
+        gpc.config.reduce_scatter_memory = reduce_scatter_memory

     return model

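Note: initialize_model only registers an empty {'data': [], 'used': []} slot per shard shape; the actual buffers are materialized lazily by check_reduce_scatter_memory_pool on first use and the pool then grows with the number of concurrently in-flight reduce-scatters. An illustrative helper, not part of the commit, for inspecting how much memory the pool ends up holding:

# Illustrative helper (not part of the commit): report how much memory the
# reduce-scatter pool currently holds, e.g. before dumping a CUDA memory snapshot.
from internlm.core.context import global_context as gpc


def reduce_scatter_pool_stats() -> None:
    total_bytes = 0
    for shape, slot in gpc.config.reduce_scatter_memory.items():
        in_use = sum(1 for used in slot['used'] if used)
        for buf in slot['data']:
            total_bytes += buf.numel() * buf.element_size()
        print(f"shape={shape}: buffers={len(slot['data'])}, in_use={in_use}")
    print(f"total pooled memory: {total_bytes / 1024**2:.1f} MiB")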
train.py
@@ -299,6 +299,7 @@ def main(args):

         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
+            gpc.config.fstp_handler.reduce_scatter_memory = {}
             torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()

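Note: emptying zero_const_pool and reduce_scatter_memory on the handler before the snapshot appears intended to drop references to cached buffers so the dump reflects live usage. A hedged sketch of the surrounding debugging flow; the wrapper below is illustrative, and _dump_snapshot together with its companion _record_memory_history are private PyTorch APIs whose signatures vary across releases:

# Illustrative wrapper around the per-rank memory snapshot used in this hunk.
import torch

from internlm.core.context import global_context as gpc


def snapshot_cuda_memory(tag: str = "my_snapshot") -> None:
    # Drop references to cached pool buffers first so the snapshot reflects live usage.
    if gpc.config.fstp_handler is not None:
        gpc.config.fstp_handler.zero_const_pool = {}
        gpc.config.fstp_handler.reduce_scatter_memory = {}
    # torch.cuda.memory._record_memory_history() is typically enabled once at startup
    # (private, version-dependent API).
    torch.cuda.memory._dump_snapshot(f"{tag}_{gpc.get_global_rank()}.pickle")
    torch.cuda.reset_peak_memory_stats()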