diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh
index 92fd481c5..73f231292 100644
--- a/examples/language/opt/run_gemini.sh
+++ b/examples/language/opt/run_gemini.sh
@@ -4,10 +4,17 @@ export MEMCAP=${MEMCAP:-0}
 # Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
 export MODEL=${MODEL:-"125m"}
 export GPUNUM=${GPUNUM:-1}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
 
 # make directory for logs
 mkdir -p ./logs
 
+if [ "${USE_SHARD_INIT}" = "true" ]; then
+  USE_SHARD_INIT="--shardinit"
+else
+  USE_SHARD_INIT=""
+fi
+
 export MODLE_PATH="facebook/opt-${MODEL}"
 
 # HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
@@ -17,4 +24,5 @@ torchrun \
   train_gemini_opt.py \
   --mem_cap ${MEMCAP} \
   --model_name_or_path ${MODLE_PATH} \
+  ${USE_SHARD_INIT} \
   --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py
index 64426ba42..1546b31ba 100755
--- a/examples/language/opt/train_gemini_opt.py
+++ b/examples/language/opt/train_gemini_opt.py
@@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
+from colossalai.tensor import ProcessGroup, ShardSpec
+
 
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
@@ -102,6 +104,11 @@
         help="Model type to use if training from scratch.",
         choices=MODEL_TYPES,
     )
+    parser.add_argument(
+        "--shardinit",
+        action="store_true",
+        help="Initialize the model with tensor parallelism.",
+    )
     parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
     parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
     args = parser.parse_args()
@@ -159,16 +166,30 @@ def main():
     else:
         init_dev = get_current_device()
 
+    # shard init parameters
+    if args.shardinit:
+        logger.info("Sharding initialization enabled", ranks=[0])
+    else:
+        logger.info("Skipping sharding initialization", ranks=[0])
+
+    world_size = torch.distributed.get_world_size()
+    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+
     # build model
     if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
         # currently, there has a bug in pretrained opt-13b
         # we can not import it until huggingface fix it
         logger.info("Train a new model from scratch", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM(config)
     else:
         logger.info("Finetune a pre-trained model", ranks=[0])
-        with ColoInitContext(device=init_dev, dtype=torch.half):
+        with ColoInitContext(device=init_dev, dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
             model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                    from_tf=bool(".ckpt" in args.model_name_or_path),
                                                    config=config,
@@ -179,7 +200,8 @@
 
     numel = sum([p.numel() for p in model.parameters()])
     PLACEMENT_POLICY = 'cpu'
-    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
+                      pin_memory=True, strict_ddp_mode=args.shardinit)
     optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
 
     SEQ_LEN = 1024
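Usage sketch (not part of the patch; assumes the environment-variable defaults defined in run_gemini.sh above): setting USE_SHARD_INIT=true makes the wrapper pass --shardinit to train_gemini_opt.py, which builds a ProcessGroup spanning all ranks, uses ShardSpec([-1], [world_size]) as the default distribution spec inside ColoInitContext, and enables strict_ddp_mode in GeminiDDP.

# Hypothetical invocation: OPT-125m on 2 GPUs with sharded (tensor-parallel) initialization.
GPUNUM=2 MODEL=125m BS=8 USE_SHARD_INIT=true bash ./run_gemini.sh

# Leaving USE_SHARD_INIT unset (default "false") keeps the previous behaviour:
# each rank materializes a full parameter replica during initialization.
GPUNUM=2 MODEL=125m BS=8 bash ./run_gemini.sh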