From df1d6dc553b6b1baf05c5a79716a52476f71948b Mon Sep 17 00:00:00 2001
From: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Date: Tue, 3 Jan 2023 17:49:00 +0800
Subject: [PATCH] [examples] using args and combining two versions for PaLM
 (#2284)

---
 examples/language/palm/palm_config.py |   6 --
 examples/language/palm/run.sh         |  12 ++-
 examples/language/palm/train.py       | 123 ++++++++++++++++++--------
 3 files changed, 97 insertions(+), 44 deletions(-)
 delete mode 100644 examples/language/palm/palm_config.py

diff --git a/examples/language/palm/palm_config.py b/examples/language/palm/palm_config.py
deleted file mode 100644
index 9fb9a900f..000000000
--- a/examples/language/palm/palm_config.py
+++ /dev/null
@@ -1,6 +0,0 @@
-SEQ_LENGTH = 1024
-BATCH_SIZE = 4
-NUM_EPOCHS = 4
-TPDEGREE = 2
-USE_SHARD_INIT = False
-placement = 'cpu'
\ No newline at end of file
diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh
index 700401786..4aa868953 100644
--- a/examples/language/palm/run.sh
+++ b/examples/language/palm/run.sh
@@ -1 +1,11 @@
-env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --config palm_config.py
+# distplan in ["colossalai", "pytorch"]
+export DISTPAN="colossalai"
+
+# The following options only valid when DISTPAN="colossalai"
+export TPDEGREE=1
+export GPUNUM=1
+export PLACEMENT='cpu'
+export USE_SHARD_INIT=False
+export BATCH_SIZE=4
+
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
\ No newline at end of file
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 135badba4..89b4e058f 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -21,19 +21,51 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 
 # constants
 
-NUM_BATCHES = int(20)
-BATCH_SIZE = 4
+NUM_BATCHES = int(1000)
 GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 1024
-TPDEGREE = 1
-USE_SHARD_INIT = False
-placement = 'cpu'
+def parse_args():
+    parser = colossalai.get_default_parser()
+    parser.add_argument(
+        "--distplan",
+        type=str,
+        default='colossalai',
+        help="The distributed plan [colossalai, pytorch].",
+    )
+    parser.add_argument(
+        "--tp_degree",
+        type=int,
+        default=1,
+        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--placement",
+        type=str,
+        default='cpu',
+        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--shardinit",
+        type=bool,
+        default=False,
+        help=
+        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="batch size per DP group of training.",
+    )
+    args = parser.parse_args()
+    return args
+
 # helpers
 
 def cycle(loader):
     while True:
@@ -73,22 +105,11 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
     return model
 
 
-# instantiate GPT-like decoder model
-
-parser = colossalai.get_default_parser()
-args = parser.parse_args()
+args = parse_args()
+if args.distplan not in ["colossalai", "pytorch"]:
+    raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
-colossalai.launch_from_torch(config=args.config, seed=42)
-
-# instantiate GPT-like decoder model
-
-default_pg = ProcessGroup(tp_degree=TPDEGREE)
-default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
-ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
-
-with ctx:
-    model = PaLM(num_tokens=256, dim=512, depth=8)
-    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+colossalai.launch_from_torch(config={})
 
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
@@ -114,34 +135,62 @@ class TextSamplerDataset(Dataset):
 
 train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
 val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
 
-train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
-val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
+train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
+val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))
 
-#tensor_parallelize(model, pg)
+if args.distplan == "colossalai":
+    # instantiate GPT-like decoder model
 
-pg = default_pg
-model = gemini_zero_dpp(model, pg, placement)
+    default_pg = ProcessGroup(tp_degree=args.tp_degree)
+    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+    ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
+
+    with ctx:
+        model = PaLM(num_tokens=256, dim=512, depth=8)
+        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+
+    pg = default_pg
+    #tensor_parallelize(model, pg)
+    model = gemini_zero_dpp(model, pg, args.placement)
+
+    #optimizer
+
+    #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
+else:
+    model = PaLM(num_tokens=256, dim=512, depth=8)
+    model = AutoregressiveWrapper(model, max_seq_len=2048)
+    model.cuda()
+    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
-#optimizer
-optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
 
 
 # training
 
 model.train()
 
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
-    optimizer.zero_grad()
+    if args.distplan == "colossalai":
+        optimizer.zero_grad()
 
-    loss = model(next(train_loader))
-    # loss.backward()
-    optimizer.backward(loss)
+        loss = model(next(train_loader))
+        # loss.backward()
+        optimizer.backward(loss)
 
-    print(f"training loss: {loss.item()}")
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-    # optim.step()
-    # optim.zero_grad()
-    optimizer.step()
+        print(f"training loss: {loss.item()}")
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+        # optim.step()
+        # optim.zero_grad()
+        optimizer.step()
+    else:
+        for __ in range(GRADIENT_ACCUMULATE_EVERY):
+            loss = model(next(train_loader))
+            loss.backward()
+
+        print(f"training loss: {loss.item()}")
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+        optim.step()
+        optim.zero_grad()
 
     # TODO
     # if i % VALIDATE_EVERY == 0:
@@ -158,4 +207,4 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
 
     # sample = model.generate(inp[None, ...], GENERATE_LENGTH)
     # output_str = decode_tokens(sample[0])
-    # print(output_str)
+    # print(output_str)
\ No newline at end of file
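
With this patch the PaLM example takes its configuration from command-line arguments instead of palm_config.py, so it can also be launched without run.sh. The invocation below is a minimal usage sketch, not part of the commit itself: it assumes the patched examples/language/palm/train.py is the entry point and simply mirrors the flags defined in parse_args(); adjust --nproc_per_node, --tp_degree and --batch_size to the available GPUs and memory.

  # single-GPU run of the colossalai plan; flag names come from parse_args() in train.py above
  env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=1 --master_port 29501 train.py \
      --distplan colossalai --tp_degree 1 --placement cpu --batch_size 4 2>&1 | tee run.log

Note that --shardinit is declared with type=bool, so passing an explicit value such as "False" is parsed as True (any non-empty string is truthy); leave the flag out to keep the default of False.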