
support shardinit option to avoid OPT OOM initializing problem (#3037)

Co-authored-by: poe <poe@nemoramo>
Branch: pull/3056/head
Author: ramos, 2 years ago, committed by GitHub
Commit: 2ef855c798
Changed files:
1. examples/language/opt/run_gemini.sh (8 changed lines)
2. examples/language/opt/train_gemini_opt.py (28 changed lines)

examples/language/opt/run_gemini.sh (8 changed lines)

@@ -4,10 +4,17 @@ export MEMCAP=${MEMCAP:-0}
 # Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
 export MODEL=${MODEL:-"125m"}
 export GPUNUM=${GPUNUM:-1}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
 
 # make directory for logs
 mkdir -p ./logs
 
+if [ ${USE_SHARD_INIT} = "true" ]; then
+    USE_SHARD_INIT="--shardinit"
+else
+    USE_SHARD_INIT=""
+fi
+
 export MODLE_PATH="facebook/opt-${MODEL}"
 
 # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
@@ -17,4 +24,5 @@ torchrun \
   train_gemini_opt.py \
   --mem_cap ${MEMCAP} \
   --model_name_or_path ${MODLE_PATH} \
+  ${USE_SHARD_INIT} \
   --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
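
Usage note: with this change, sharded initialization is toggled from the environment when launching the script; for example, `USE_SHARD_INIT=true MODEL=13b GPUNUM=8 bash run_gemini.sh` (the model size and GPU count here are illustrative, not taken from the commit) makes the script expand `${USE_SHARD_INIT}` to `--shardinit` and forward it to train_gemini_opt.py.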

examples/language/opt/train_gemini_opt.py (28 changed lines)

@@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ProcessGroup, ShardSpec
+
 
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
@@ -102,6 +104,11 @@ def parse_args():
         help="Model type to use if training from scratch.",
         choices=MODEL_TYPES,
     )
+    parser.add_argument(
+        "--shardinit",
+        action="store_true",
+        help="Initialize the model with tensor parallel",
+    )
     parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
     parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
     args = parser.parse_args()
@@ -159,16 +166,30 @@ def main():
     else:
         init_dev = get_current_device()
 
+    # shard init parameters
+    if args.shardinit:
+        logger.info("Sharding initialization !", ranks=[0])
+    else:
+        logger.info("Skipping sharding initialization", ranks=[0])
+
+    world_size = torch.distributed.get_world_size()
+    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+
     # build model
     if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
         # currently, there has a bug in pretrained opt-13b
         # we can not import it until huggingface fix it
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
@@ -179,7 +200,8 @@ def main():
     numel = sum([p.numel() for p in model.parameters()])
     PLACEMENT_POLICY = 'cpu'
-    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
+                      pin_memory=True, strict_ddp_mode=args.shardinit)
     optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
     SEQ_LEN = 1024
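
Taken together, the two diffs wire a launcher flag through to a sharded model construction. The snippet below is a minimal sketch of that pattern, not code from the commit: it assumes the process is started with torchrun and that colossalai.launch_from_torch(config={}) is used to set up distributed state (an assumption, since the launch code sits outside the shown hunks), and it uses the 125m OPT config purely as a placeholder.

    # Minimal sketch of the --shardinit initialization path (assumptions noted above).
    import colossalai
    import torch
    from transformers import AutoConfig, OPTForCausalLM

    from colossalai.nn.parallel import GeminiDDP
    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    def build_sharded_opt(shardinit: bool = True):
        # Assumed launch call; the real script's launch code is outside the diff.
        colossalai.launch_from_torch(config={})
        world_size = torch.distributed.get_world_size()

        # With shard init, parameters are created already sharded along their last
        # dimension across all ranks, so no single rank holds the full fp16 model
        # during construction (the init-time OOM the commit message targets).
        shard_pg = ProcessGroup(tp_degree=world_size) if shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if shardinit else None

        config = AutoConfig.from_pretrained("facebook/opt-125m")  # placeholder size
        with ColoInitContext(device=get_current_device(), dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)

        # The commit pairs strict_ddp_mode with --shardinit when wrapping in Gemini.
        return GeminiDDP(model, device=get_current_device(), placement_policy='cpu',
                         pin_memory=True, strict_ddp_mode=shardinit)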
