mirror of https://github.com/hpcaitech/ColossalAI
support shardinit option to avoid OPT OOM initializing problem (#3037)
Co-authored-by: poe <poe@nemoramo>pull/3056/head
parent
29386a54e6
commit
2ef855c798
|
@ -4,10 +4,17 @@ export MEMCAP=${MEMCAP:-0}
|
|||
# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
|
||||
export MODEL=${MODEL:-"125m"}
|
||||
export GPUNUM=${GPUNUM:-1}
|
||||
export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
|
||||
|
||||
# make directory for logs
|
||||
mkdir -p ./logs
|
||||
|
||||
if [ ${USE_SHARD_INIT} = "true" ]; then
|
||||
USE_SHARD_INIT="--shardinit"
|
||||
else
|
||||
USE_SHARD_INIT=""
|
||||
fi
|
||||
|
||||
export MODLE_PATH="facebook/opt-${MODEL}"
|
||||
|
||||
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
|
||||
|
@ -17,4 +24,5 @@ torchrun \
|
|||
train_gemini_opt.py \
|
||||
--mem_cap ${MEMCAP} \
|
||||
--model_name_or_path ${MODLE_PATH} \
|
||||
${USE_SHARD_INIT} \
|
||||
--batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
|
||||
|
|
|
@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
|
|||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
|
||||
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
|
@ -102,6 +104,11 @@ def parse_args():
|
|||
help="Model type to use if training from scratch.",
|
||||
choices=MODEL_TYPES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shardinit",
|
||||
action="store_true",
|
||||
help="Initialize the model with tensor parallel",
|
||||
)
|
||||
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
|
||||
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
|
||||
args = parser.parse_args()
|
||||
|
@ -159,16 +166,30 @@ def main():
|
|||
else:
|
||||
init_dev = get_current_device()
|
||||
|
||||
# shard init prameters
|
||||
if args.shardinit:
|
||||
logger.info("Sharding initialization !", ranks=[0])
|
||||
else:
|
||||
logger.info("Skipping sharding initialization", ranks=[0])
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
|
||||
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
|
||||
|
||||
# build model
|
||||
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
||||
# currently, there has a bug in pretrained opt-13b
|
||||
# we can not import it until huggingface fix it
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half):
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half):
|
||||
with ColoInitContext(device=init_dev, dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
|
@ -179,7 +200,8 @@ def main():
|
|||
|
||||
numel = sum([p.numel() for p in model.parameters()])
|
||||
PLACEMENT_POLICY = 'cpu'
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
|
||||
pin_memory=True, strict_ddp_mode=args.shardinit)
|
||||
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
|
||||
|
||||
SEQ_LEN = 1024
|
||||
|
|
Loading…
Reference in New Issue