mirror of https://github.com/InternLM/InternLM

Commit ed7232777a (parent 4742271154): support reduce scatter memory pool
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 6144
 NUM_ATTENTION_HEAD = 48
 MLP_RATIO = 8 / 3
-NUM_LAYER = 40
+NUM_LAYER = 60
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
    micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -161,8 +161,8 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="origin_tp", overlap=False),
+    zero1=dict(size=4, fsdp=False),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
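
For readability, here is an annotated sketch of the updated parallel block. The comments are added in this write-up for context and are not part of the original config; the authoritative field descriptions are in the config docstring above.

# zero1=4: optimizer states are sharded within groups of 4 ranks instead of
#          across the whole data-parallel group (the previous size=-1).
# tensor:  8-way tensor parallelism, switched from "origin_tp" to "fstp",
#          with communication/computation overlap enabled.
# pipeline: no pipeline parallelism (size=1).
# sequence_parallel: kept enabled; it is used together with the fstp mode here.
parallel = dict(
    zero1=dict(size=4, fsdp=False),
    tensor=dict(size=8, mode="fstp", overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)
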
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp"),
+    tensor=dict(size=8, mode="fstp", overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device
 
 logger = get_logger(__file__)
 
@@ -148,6 +149,18 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
                                                      async_op=async_op)
     return output, handle
 
+def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    size = (input_.shape[0] // world_size, *input_.shape[1:])
+    index = check_reduce_scatter_memory_pool(size)
+    output = gpc.config.reduce_scatter_memory[size]['data'][index]
+    setattr(output, "index", index)
+    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
+                                                     group=process_group,
+                                                     async_op=async_op)
+    return output, handle
+
 
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFunc(torch.autograd.Function):
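
The new helper reuses pre-allocated output buffers keyed by the per-rank shard shape instead of allocating a fresh tensor for every reduce-scatter, and it tags each buffer with its slot index so the buffer can be returned to the pool later. A minimal, self-contained sketch of the same bookkeeping follows; the class and variable names here are illustrative and not part of the commit.

import torch


class BufferPool:
    """Reusable zero-initialized buffers keyed by shape.

    acquire() returns (slot_index, tensor), allocating a new slot only when
    every existing buffer of that shape is in use; release() marks the slot
    free so the next acquire() of the same shape reuses the storage.
    """

    def __init__(self, dtype=torch.float16, device="cpu"):
        self.dtype = dtype
        self.device = device
        self.pool = {}  # shape -> {"data": [tensors], "used": [bools]}

    def acquire(self, shape):
        entry = self.pool.setdefault(shape, {"data": [], "used": []})
        for idx, used in enumerate(entry["used"]):
            if not used:
                entry["used"][idx] = True
                return idx, entry["data"][idx]
        entry["data"].append(torch.zeros(shape, dtype=self.dtype, device=self.device))
        entry["used"].append(True)
        return len(entry["data"]) - 1, entry["data"][-1]

    def release(self, shape, idx):
        self.pool[shape]["used"][idx] = False


# Two requests for the same shard shape end up sharing one allocation once
# the first buffer has been released back to the pool.
pool = BufferPool()
idx0, buf0 = pool.acquire((4, 8))
pool.release((4, 8), idx0)
idx1, buf1 = pool.acquire((4, 8))
assert idx1 == idx0 and buf1 is buf0
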
@@ -404,12 +417,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
                 # assert hasattr(bias, "_fstp_all_reduce_str")
                 # all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
                 # grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-                grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+
+                grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
                 assert hasattr(weight, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
                 grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
-                    grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+                    grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
                     assert hasattr(bias, "_fstp_reduce_scatter_str")
                     all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
                     grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
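
In the backward pass the reduced gradient shard is no longer allocated per call: the asynchronous reduce-scatter now writes into a pooled buffer, the (handle, buffer) pair is stashed under the parameter's key, and a cached zero tensor stands in until the optimizer consumes the real result. A condensed sketch of that producer side follows; it reuses the helpers introduced in this diff and is not runnable on its own, since it assumes an initialized process group and the handler object from the surrounding code.

import torch


def launch_grad_reduce_scatter(grad_weight, weight, process_group, all_gather_handler):
    # Start the reduce-scatter into a pooled buffer without blocking backward.
    grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(
        grad_weight, process_group, async_op=True
    )
    # Register the pending (handle, buffer) pair under the parameter's key;
    # the optimizer waits on it later and releases the buffer back to the pool.
    all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (
        handle_grad_weight,
        grad_weight_async,
    )
    # Hand autograd a cached zero tensor of the sharded shape as a placeholder.
    world_size = torch.distributed.get_world_size(process_group)
    return all_gather_handler.get_zero_by_shape(
        (grad_weight.shape[0] // world_size, *grad_weight.shape[1:]),
        dtype=grad_weight.dtype,
        device=grad_weight.device,
    )
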
@@ -521,3 +535,37 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
+
+def check_reduce_scatter_memory_pool(key):
+
+    return_idx = 0
+
+    # if key not in dict
+    if key not in gpc.config.reduce_scatter_memory:
+        gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+    # if the data is empty
+    if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = 0
+        return return_idx
+    else:  # if not empty
+        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
+            if used == False:
+                gpc.config.reduce_scatter_memory[key]['used'][index] = True
+                return_idx = index
+                return return_idx
+        # if the memory pool is all used
+        length = len(gpc.config.reduce_scatter_memory[key]['data'])
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = length
+        return return_idx
+
+def release_reduce_scatter_memory_pool(size, index):
+    gpc.config.reduce_scatter_memory[size]['used'][index] = False
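
Taken together, the two helpers give each shard shape its own free-list of buffers: check_reduce_scatter_memory_pool marks a free slot as used, growing the pool only when every buffer of that shape is busy, and release_reduce_scatter_memory_pool hands the slot back. A hedged usage sketch of that life-cycle follows; the shape is an example value, and it assumes gpc.config.reduce_scatter_memory has been initialized as in the initialize_model() change further down.

size = (2048, 6144)                             # one per-rank gradient shard shape (example value)
index = check_reduce_scatter_memory_pool(size)  # reserve a slot, allocating a buffer if none is free
output = gpc.config.reduce_scatter_memory[size]['data'][index]

# ... reduce-scatter into `output` and accumulate it into the parameter gradient ...

release_reduce_scatter_memory_pool(size=size, index=index)  # the slot can now be reused by the next call
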
@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward
+from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -353,7 +353,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle.wait()
                 _param.grad.add_(_grad)
                 # self._fstp_handler.reduce_scatter_handlers[key] = None
-                del _grad
+                # del _grad
+                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                 del self._fstp_handler.reduce_scatter_handlers[key]
                 self._fstp_handler.reduce_scatter_handlers[key] = None
                 assert key in self._fstp_handler.reduce_scatter_handlers
@@ -395,7 +396,8 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle.wait()
                 _param.grad.add_(_grad)
                 # self._fstp_handler.reduce_scatter_handlers[key] = None
-                del _grad
+                # del _grad
+                release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
                 del self._fstp_handler.reduce_scatter_handlers[key]
                 self._fstp_handler.reduce_scatter_handlers[key] = None
                 assert key in self._fstp_handler.reduce_scatter_handlers
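
On the consumer side the optimizer now waits on the stored handle, folds the reduced shard into the parameter gradient, and returns the pooled buffer instead of deleting it; the buffer's pool slot is recovered from the index attribute that reduce_scatter_raw_memory_pool attached to it. A condensed sketch of that pattern follows, reusing the handler layout from this diff rather than being a runnable standalone snippet.

def consume_reduce_scatter_result(fstp_handler, key, _param):
    comm_handle, _grad = fstp_handler.reduce_scatter_handlers[key]
    comm_handle.wait()            # the asynchronous reduce-scatter has finished here
    _param.grad.add_(_grad)       # accumulate the reduced gradient shard
    # Hand the pooled buffer back instead of freeing it, so the next backward
    # pass with the same shard shape reuses this allocation.
    release_reduce_scatter_memory_pool(size=tuple(_grad.size()), index=_grad.index)
    del fstp_handler.reduce_scatter_handlers[key]
    fstp_handler.reduce_scatter_handlers[key] = None
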
@@ -51,7 +51,7 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -123,7 +123,8 @@ def initialize_model():
     mlp_ratio = gpc.config.MLP_RATIO
     mlp_hidden_size = int(hidden_size * mlp_ratio)
     mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-    size_key = [(3 * hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (hidden_size, hidden_size)]
+    world_size = gpc.get_world_size(ParallelMode.TENSOR)
+    size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
     module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
     for i in range(2):
         weight = {}
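
The pool keys become the per-rank shard shapes of the reduce-scattered gradients rather than full weight shapes, since each tensor-parallel rank only receives 1/world_size of dimension 0. As a worked example, using the values shown in the config hunks above (HIDDEN_SIZE = 6144, MLP_RATIO = 8/3, tensor-parallel size 8); the mapping of shapes to projections in the comments follows the weight shapes used in the next hunk, and the print is just for illustration.

hidden_size = 6144
mlp_ratio = 8 / 3
mlp_hidden_size = int(hidden_size * mlp_ratio)                 # 16384
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)   # still 16384 (already a multiple of 256)
world_size = 8                                                 # tensor-parallel group size

size_key = [
    (3 * hidden_size // world_size, hidden_size),   # Wqkv grad shard:     (2304, 6144)
    (mlp_hidden_size // world_size, hidden_size),   # w1/w2 grad shard:    (2048, 6144)
    (hidden_size // world_size, mlp_hidden_size),   # w3 grad shard:       (768, 16384)
    (hidden_size // world_size, hidden_size),       # out_proj grad shard: (768, 6144)
]
print(size_key)
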
@@ -131,21 +132,26 @@ def initialize_model():
             if name == 'Wqkv':
                 weight[name] = torch.zeros((3 * hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             elif name == 'out_proj':
                 weight[name] = torch.zeros((hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             elif name == 'w1' or name == 'w2':
                 weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
             else:
                 weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
                                            dtype=gpc.config.model.get("dtype", torch.half),
-                                           device='cuda').contiguous()
+                                           device=get_current_device()).contiguous()
         block_memory[i] = weight
+    reduce_scatter_memory = {}
+    for key in size_key:
+        reduce_scatter_memory[key] = {'data': [], 'used': []}
+
     gpc.config.block_memory = block_memory
+    gpc.config.reduce_scatter_memory = reduce_scatter_memory
 
     return model
 
train.py
@@ -299,6 +299,7 @@ def main(args):
 
         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
+            gpc.config.fstp_handler.reduce_scatter_memory = {}
             torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()
 