mirror of https://github.com/hpcaitech/ColossalAI
support shardinit option to avoid OOM when initializing OPT (#3037)
Co-authored-by: poe <poe@nemoramo>
branch pull/3056/head
parent 29386a54e6
commit 2ef855c798
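In brief: the commit threads a new shardinit switch from the shell launcher into train_gemini_opt.py, so that OPT parameters are created already sharded across ranks during initialization instead of fully replicated on every GPU, avoiding the out-of-memory failure described in #3037.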
[shell launcher script]

@@ -4,10 +4,17 @@ export MEMCAP=${MEMCAP:-0}
 # Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
 export MODEL=${MODEL:-"125m"}
 export GPUNUM=${GPUNUM:-1}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
 
 # make directory for logs
 mkdir -p ./logs
 
+if [ ${USE_SHARD_INIT} = "true" ]; then
+    USE_SHARD_INIT="--shardinit"
+else
+    USE_SHARD_INIT=""
+fi
+
 export MODEL_PATH="facebook/opt-${MODEL}"
 
 # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
@@ -17,4 +24,5 @@ torchrun \
 train_gemini_opt.py \
 --mem_cap ${MEMCAP} \
 --model_name_or_path ${MODEL_PATH} \
+${USE_SHARD_INIT} \
 --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
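With the launcher change in place, sharded initialization is toggled purely through the environment. A hypothetical invocation (the script name run_gemini.sh is an assumption, not shown in this diff; model size, GPU count, and batch size are illustrative):

    # shard-initialize OPT-1.3B across 2 GPUs instead of replicating it on each
    USE_SHARD_INIT="true" MODEL="1.3b" GPUNUM=2 BS=8 bash ./run_gemini.sh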
[train_gemini_opt.py]

@@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from colossalai.tensor import ProcessGroup, ShardSpec
+
 
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
@@ -102,6 +104,11 @@ def parse_args():
         help="Model type to use if training from scratch.",
         choices=MODEL_TYPES,
     )
+    parser.add_argument(
+        "--shardinit",
+        action="store_true",
+        help="Initialize the model with tensor parallelism.",
+    )
     parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
     parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
     args = parser.parse_args()
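The new flag can also be passed directly when launching the training script by hand. A sketch (GPU count, model name, and batch size are illustrative; the remaining arguments keep their defaults):

    torchrun --nproc_per_node 2 train_gemini_opt.py \
        --model_name_or_path facebook/opt-125m \
        --mem_cap 0 \
        --shardinit \
        --batch_size 8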
@@ -159,16 +166,30 @@ def main():
     else:
         init_dev = get_current_device()
 
+    # shard init parameters
+    if args.shardinit:
+        logger.info("Sharding initialization!", ranks=[0])
+    else:
+        logger.info("Skipping sharding initialization", ranks=[0])
+
+    world_size = torch.distributed.get_world_size()
+    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+
     # build model
     if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
         # currently there is a bug in the pretrained opt-13b model
         # we cannot import it until huggingface fixes it
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
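The effect of the two new keyword arguments is that every parameter created inside the context is born sharded. A minimal self-contained sketch of the same pattern, using only the calls visible in this diff and assuming a distributed environment has already been launched (the toy nn.Sequential model is illustrative, not from the commit):

    import torch
    import torch.nn as nn
    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    world_size = torch.distributed.get_world_size()
    shard_pg = ProcessGroup(tp_degree=world_size)     # tensor-parallel group spanning all ranks
    dist_spec = ShardSpec([-1], [world_size])         # split each tensor's last dim world_size ways

    with ColoInitContext(device=get_current_device(), dtype=torch.half,
                         default_dist_spec=dist_spec, default_pg=shard_pg):
        # parameters are materialized one shard per rank, so no rank ever
        # holds a full copy of the fp16 weights during initialization
        model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))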
@@ -179,7 +200,8 @@ def main():
 
     numel = sum([p.numel() for p in model.parameters()])
     PLACEMENT_POLICY = 'cpu'
-    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
+                      pin_memory=True, strict_ddp_mode=args.shardinit)
     optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
 
     SEQ_LEN = 1024