diff --git a/examples/language/palm/palm_config.py b/examples/language/palm/palm_config.py
new file mode 100644
index 000000000..9fb9a900f
--- /dev/null
+++ b/examples/language/palm/palm_config.py
@@ -0,0 +1,6 @@
+SEQ_LENGTH = 1024
+BATCH_SIZE = 4
+NUM_EPOCHS = 4
+TPDEGREE = 2
+USE_SHARD_INIT = False
+placement = 'cpu'
\ No newline at end of file
diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py
index 105991967..aaf5fd050 100644
--- a/examples/language/palm/palm_pytorch/palm_pytorch.py
+++ b/examples/language/palm/palm_pytorch/palm_pytorch.py
@@ -47,7 +47,9 @@ class RotaryEmbedding(nn.Module):
     def forward(self, max_seq_len, *, device):
         seq = torch.arange(max_seq_len, device=device)
         #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
-        freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+        #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+        i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
+        freqs = torch.matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
         return torch.cat((freqs, freqs), dim=-1)
 
 
diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh
new file mode 100644
index 000000000..154d037d5
--- /dev/null
+++ b/examples/language/palm/run.sh
@@ -0,0 +1 @@
+env OMP_NUM_THREADS=12 torchrun --nproc_per_node 8 --master_port 29501 train.py --config palm_config.py
\ No newline at end of file
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index ba243e507..f8e58eae6 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -9,6 +9,16 @@ from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
+from packaging import version
+
+import colossalai
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.utils import MultiTimer, get_current_device
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.logging import disable_existing_loggers, get_dist_logger
 
 # constants
 
@@ -20,6 +30,9 @@ VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 1024
+TPDEGREE = 2
+USE_SHARD_INIT = False
+placement = 'cpu'
 
 # helpers
 
@@ -37,16 +50,55 @@ def decode_token(token):
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
+# Gemini + ZeRO DDP
+def gemini_zero_ddp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
+    cai_version = colossalai.__version__
+    if version.parse(cai_version) > version.parse("0.1.10"):
+        from colossalai.nn.parallel import GeminiDDP
+        model = GeminiDDP(model,
+                          device=get_current_device(),
+                          placement_policy=placement_policy,
+                          pin_memory=True,
+                          search_range_mb=32)
+    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+        from colossalai.gemini import ChunkManager, GeminiManager
+        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+        chunk_manager = ChunkManager(chunk_size,
+                                     pg,
+                                     enable_distributed_storage=True,
+                                     init_device=GeminiManager.get_default_device(placement_policy))
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        model = ZeroDDP(model, gemini_manager)
+    else:
+        raise NotImplementedError(f"CAI version {cai_version} is not supported")
+    return model
+
+# parse CLI args and launch Colossal-AI from torchrun
+
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+disable_existing_loggers()
+colossalai.launch_from_torch(config=args.config, seed=42)
+
 # instantiate GPT-like decoder model
 
-model = PaLM(num_tokens=256, dim=512, depth=8)
+default_pg = ProcessGroup(tp_degree=TPDEGREE)
+default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
+ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
+with ctx:
+    model = PaLM(num_tokens=256, dim=512, depth=8)
+    model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+    model.cuda()
 
-model = AutoregressiveWrapper(model, max_seq_len=2048)
-model.cuda()
 
 
 # prepare enwik8 data
 
+# model = PaLM(num_tokens=256, dim=512, depth=8)
+
+# model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
+# model.cuda()
+
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
     trX, vaX = np.split(X, [int(90e6)])
@@ -74,9 +126,20 @@ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
 train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
 val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
 
-# optimizer
+#tensor_parallelize(model, pg)
+
+pg = default_pg
+# model = GeminiDDP(model,
+#                   device=get_current_device(),
+#                   placement_policy="auto",
+#                   pin_memory=True,
+#                   search_range_mb=32)
+model = gemini_zero_ddp(model, pg, placement)
+
+#optimizer
 
-optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
+#optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 
 # training
 
@@ -89,8 +152,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         print(f"training loss: {loss.item()}")
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
 
-        optim.step()
-        optim.zero_grad()
+        # optim.step()
+        # optim.zero_grad()
+        optimizer.step()
+        optimizer.zero_grad()
 
     if i % VALIDATE_EVERY == 0:
         model.eval()
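
A note on the RotaryEmbedding change in palm_pytorch.py (commentary, not part of the patch): the reshape-plus-matmul expression is just the outer product that torch.outer (and the original einsum) computed, written out explicitly. A minimal standalone sketch of the equivalence; the sizes 8 and 16 are illustrative stand-ins, not values taken from the model:

    import torch

    # stand-ins for the position indices and self.inv_freq of RotaryEmbedding
    seq = torch.arange(8, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 16, 2).float() / 16))

    # outer product via reshape + matmul, as in the patched forward()
    i, j = len(seq), len(inv_freq)
    freqs_matmul = torch.matmul(seq.reshape(i, 1), inv_freq.reshape(1, j))

    assert torch.allclose(freqs_matmul, torch.outer(seq, inv_freq))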
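A second note on configuration (commentary, not part of the patch): train.py re-declares TPDEGREE, USE_SHARD_INIT and placement as module-level constants even though palm_config.py now defines them and run.sh passes that file via --config. If palm_config.py is meant to be the single source of truth, the values it defines should be reachable through gpc.config once colossalai.launch_from_torch has run; a rough sketch of reading them instead of redeclaring them, untested against this exact example:

    from colossalai.core import global_context as gpc

    # After colossalai.launch_from_torch(config=args.config, seed=42), the names
    # defined at the top level of palm_config.py are exposed on gpc.config.
    TPDEGREE = gpc.config.TPDEGREE
    USE_SHARD_INIT = gpc.config.USE_SHARD_INIT
    placement = gpc.config.placement
    BATCH_SIZE = gpc.config.BATCH_SIZE
    SEQ_LEN = gpc.config.SEQ_LENGTH  # note: the config file names it SEQ_LENGTH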