diff --git a/examples/tutorial/handson1/README.md b/examples/tutorial/handson1/README.md new file mode 100644 index 000000000..dcbdc1e00 --- /dev/null +++ b/examples/tutorial/handson1/README.md @@ -0,0 +1,27 @@ +# Handson 1: Multi-dimensional Parallelism with Colossal-AI + + +## Install Colossal-AI and other dependencies + +```bash +sh install.sh +``` + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + + +## Run on 2*2 device mesh + +Current configuration setting on `config.py` is TP=2, PP=2. + +```bash +colossalai run --nproc_per_node 4 train.py --config config.py +``` \ No newline at end of file diff --git a/examples/tutorial/handson1/config.py b/examples/tutorial/handson1/config.py new file mode 100644 index 000000000..2450ab1c7 --- /dev/null +++ b/examples/tutorial/handson1/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/handson1/install.sh b/examples/tutorial/handson1/install.sh new file mode 100644 index 000000000..252f6bcca --- /dev/null +++ b/examples/tutorial/handson1/install.sh @@ -0,0 +1,4 @@ +pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 +pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org +pip install titans +colossalai check -i \ No newline at end of file diff --git a/examples/tutorial/handson1/train.py b/examples/tutorial/handson1/train.py new file mode 100644 index 000000000..1fb34d806 --- /dev/null +++ b/examples/tutorial/handson1/train.py @@ -0,0 +1,116 @@ +import os +import colossalai +import torch + +from tqdm import tqdm +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import is_using_pp, get_dataloader +from colossalai.pipeline.pipelinable import PipelinableContext +from titans.model.vit.vit import _create_vit_model +from titans.dataloader.cifar10 import build_cifar + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + 
logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition( + 1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data/cifar10') + train_dataloader, test_dataloader = build_cifar( + gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters( + ), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + data_iter = iter(train_dataloader) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/handson2/README.md b/examples/tutorial/handson2/README.md new file mode 100644 index 000000000..03ab7a1b4 --- /dev/null +++ b/examples/tutorial/handson2/README.md @@ -0,0 +1,20 @@ +# Handson 2: Sequence Parallelism with BERT + + +## Prepare Dataset + +We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. + +```bash +export DATA=/path/to/data +``` + + +## Run on 2*2 device mesh + +Current configuration setting on `config.py` is TP=2, PP=2. 
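Before launching, it can help to sanity-check that the value passed to `--nproc_per_node` matches the parallel layout declared in the `parallel` dict of `config.py`. The snippet below is a minimal, self-contained sketch of that arithmetic (plain Python, no Colossal-AI imports); it assumes the usual decomposition world size = data-parallel size x pipeline size x tensor size, with sequence parallelism counted as a tensor-parallel mode.

```python
# Sanity-check sketch (assumption: world size = DP x PP x TP, with
# sequence parallelism treated as a tensor-parallel mode).
parallel = dict(tensor=dict(size=4, mode='sequence'))   # as declared in this example's config.py

nproc_per_node = 4   # value passed to `colossalai run --nproc_per_node`

pp = parallel.get('pipeline', 1)                 # no pipeline key -> single stage
tp = parallel.get('tensor', {}).get('size', 1)
assert nproc_per_node % (pp * tp) == 0, 'process count must be divisible by PP * TP'
dp = nproc_per_node // (pp * tp)
print(f'pipeline={pp}, tensor={tp}, data={dp}')  # -> pipeline=1, tensor=4, data=1
```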
+ +```bash +colossalai run --nproc_per_node 4 train.py --config config.py +``` \ No newline at end of file diff --git a/examples/tutorial/handson2/config.py b/examples/tutorial/handson2/config.py new file mode 100644 index 000000000..f242dac71 --- /dev/null +++ b/examples/tutorial/handson2/config.py @@ -0,0 +1,35 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 256 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 1 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + tensor=dict(size=4, mode='sequence') +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/handson2/train.py b/examples/tutorial/handson2/train.py new file mode 100644 index 000000000..1fb34d806 --- /dev/null +++ b/examples/tutorial/handson2/train.py @@ -0,0 +1,116 @@ +import os +import colossalai +import torch + +from tqdm import tqdm +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.utils import is_using_pp, get_dataloader +from colossalai.pipeline.pipelinable import PipelinableContext +from titans.model.vit.vit import _create_vit_model +from titans.dataloader.cifar10 import build_cifar + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition( + 1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', '../data/cifar10') + train_dataloader, test_dataloader = build_cifar( + gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = 
CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters( + ), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + data_iter = iter(train_dataloader) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/handson3/README.md similarity index 88% rename from examples/tutorial/auto_parallel/README.md rename to examples/tutorial/handson3/README.md index 93ce29e11..eb38146ad 100644 --- a/examples/tutorial/auto_parallel/README.md +++ b/examples/tutorial/handson3/README.md @@ -1,4 +1,4 @@ -# Train ResNet on CIFAR10 with auto_parallel +# Handson 3: Auto-Parallelism with ResNet ## Prepare Dataset diff --git a/examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb b/examples/tutorial/handson3/auto_ckpt_demo.ipynb similarity index 100% rename from examples/tutorial/auto_parallel/auto_ckpt_demo.ipynb rename to examples/tutorial/handson3/auto_ckpt_demo.ipynb diff --git a/examples/tutorial/auto_parallel/auto_parallel_demo.py b/examples/tutorial/handson3/auto_parallel_demo.py similarity index 100% rename from examples/tutorial/auto_parallel/auto_parallel_demo.py rename to examples/tutorial/handson3/auto_parallel_demo.py diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/handson3/bench_utils.py similarity index 100% rename from examples/tutorial/auto_parallel/bench_utils.py rename to examples/tutorial/handson3/bench_utils.py diff --git a/examples/tutorial/handson4/README.md b/examples/tutorial/handson4/README.md new file mode 100644 index 000000000..e55e3bd21 --- /dev/null +++ b/examples/tutorial/handson4/README.md @@ -0,0 +1,17 @@ +# Handson 4: Comparison of Large Batch Training Optimization + +## Prepare Dataset + +We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default. +If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command. 
+ +```bash +export DATA=/path/to/data +``` + + +## Run on 2*2 device mesh + +```bash +colossalai run --nproc_per_node 4 train.py --config config.py +``` \ No newline at end of file diff --git a/examples/tutorial/handson4/config.py b/examples/tutorial/handson4/config.py new file mode 100644 index 000000000..e019154e4 --- /dev/null +++ b/examples/tutorial/handson4/config.py @@ -0,0 +1,36 @@ +from colossalai.amp import AMP_TYPE + +# hyperparameters +# BATCH_SIZE is as per GPU +# global batch size = BATCH_SIZE x data parallel size +BATCH_SIZE = 512 +LEARNING_RATE = 3e-3 +WEIGHT_DECAY = 0.3 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 3 + +# model config +IMG_SIZE = 224 +PATCH_SIZE = 16 +HIDDEN_SIZE = 512 +DEPTH = 4 +NUM_HEADS = 4 +MLP_RATIO = 2 +NUM_CLASSES = 1000 +CHECKPOINT = False +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token + +# parallel setting +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +parallel = dict( + pipeline=2, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.NAIVE) +clip_grad_norm = 1.0 + +# pipeline config +NUM_MICRO_BATCHES = parallel['pipeline'] diff --git a/examples/tutorial/handson4/train.py b/examples/tutorial/handson4/train.py new file mode 100644 index 000000000..ffbc8f302 --- /dev/null +++ b/examples/tutorial/handson4/train.py @@ -0,0 +1,117 @@ +import os +import colossalai +import torch + +from tqdm import tqdm +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import Lars, Lamb +from colossalai.utils import is_using_pp, get_dataloader +from colossalai.pipeline.pipelinable import PipelinableContext +from titans.model.vit.vit import _create_vit_model +from titans.dataloader.cifar10 import build_cifar + + +def main(): + # initialize distributed setting + parser = colossalai.get_default_parser() + args = parser.parse_args() + + # launch from torch + colossalai.launch_from_torch(config=args.config) + + # get logger + logger = get_dist_logger() + logger.info("initialized distributed environment", ranks=[0]) + + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + use_pipeline = is_using_pp() + + # create model + model_kwargs = dict(img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method='jax', + checkpoint=gpc.config.CHECKPOINT) + + if use_pipeline: + pipelinable = PipelinableContext() + with pipelinable: + model = _create_vit_model(**model_kwargs) + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition( + 1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + else: + model = _create_vit_model(**model_kwargs) + + # count number of parameters + total_numel = 0 + for p in model.parameters(): + total_numel += p.numel() + if not gpc.is_initialized(ParallelMode.PIPELINE): + pipeline_stage = 0 + else: + pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}") + + # create dataloaders + root = os.environ.get('DATA', 
'../data/cifar10') + train_dataloader, test_dataloader = build_cifar( + gpc.config.BATCH_SIZE, root, pad_if_needed=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, + weight_decay=gpc.config.WEIGHT_DECAY) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS, + warmup_steps=gpc.config.WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader) + + logger.info("Engine is built", ranks=[0]) + + data_iter = iter(train_dataloader) + + for epoch in range(gpc.config.NUM_EPOCHS): + # training + engine.train() + + if gpc.get_global_rank() == 0: + description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + progress = tqdm(range(len(train_dataloader)), desc=description) + else: + progress = range(len(train_dataloader)) + for _ in progress: + engine.zero_grad() + engine.execute_schedule(data_iter, return_output_label=False) + engine.step() + lr_scheduler.step() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/handson5/README.md b/examples/tutorial/handson5/README.md new file mode 100644 index 000000000..d531806b3 --- /dev/null +++ b/examples/tutorial/handson5/README.md @@ -0,0 +1 @@ +# Handson 5: Fine-tuning and Serving for OPT from Hugging Face diff --git a/examples/tutorial/handson5/inference/README.md b/examples/tutorial/handson5/inference/README.md new file mode 100644 index 000000000..265608674 --- /dev/null +++ b/examples/tutorial/handson5/inference/README.md @@ -0,0 +1,77 @@ +# Overview + +This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI. + +It supports tensor parallelism, batching and caching. + +# How to run + +Run OPT-125M: +```shell +python opt_fastapi.py opt-125m +``` + +It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs. + +## Configure + +### Configure model +```shell +python opt_fastapi.py +``` +Available models: opt-125m, opt-6.7b, opt-30b, opt-175b. + +### Configure tensor parallelism +```shell +python opt_fastapi.py --tp +``` +The `` can be an integer in `[1, #GPUs]`. Default `1`. + +### Configure checkpoint +```shell +python opt_fastapi.py --checkpoint +``` +The `` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded. + +### Configure queue +```shell +python opt_fastapi.py --queue_size +``` +The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406). + +### Configure bathcing +```shell +python opt_fastapi.py --max_batch_size +``` +The `` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value. + +Note that the batch size is not always equal to ``, as some consecutive requests may not be batched. + +### Configure caching +```shell +python opt_fastapi.py --cache_size --cache_list_size +``` +This will cache `` unique requests. And for each unique request, it cache `` different results. A random result will be returned if the cache is hit. 
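To make these semantics concrete, below is a minimal usage sketch of the `ListCache` class from `cache.py` in this PR (the prompt key and the cached strings are made-up placeholders): `get` keeps raising `MissCacheError` until `list_size` distinct results have been stored for a key, after which the whole list is returned and the server picks one at random.

```python
# Usage sketch for ListCache (cache.py); key/values below are illustrative only.
import random
from cache import ListCache, MissCacheError

cache = ListCache(cache_size=8, list_size=2)            # LRU over 8 keys, 2 results per key
key = ('What is the longest river on the earth?', 64)   # (prompt, max_tokens)

try:
    outputs = cache.get(key)             # raises MissCacheError: the list is not full yet
except MissCacheError:
    cache.add(key, 'The Nile ...')       # 1st result stored; get() would still miss
    cache.add(key, 'The Amazon ...')     # 2nd result stored; the list is now full

outputs = cache.get(key)                 # hit: returns ['The Nile ...', 'The Amazon ...']
print(random.choice(outputs))            # the server returns one cached result at random
```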
+ +The `` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `` can be an integer in `[1, MAXINT]`. + +### Other configurations +```shell +python opt_fastapi.py -h +``` + +# How to benchmark +```shell +cd benchmark +locust +``` + +Then open the web interface link which is on your console. + +# Pre-process pre-trained weights + +## OPT-66B +See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py). + +## OPT-175B +See [script/process-opt-175b](./script/process-opt-175b/). \ No newline at end of file diff --git a/examples/tutorial/handson5/inference/batch.py b/examples/tutorial/handson5/inference/batch.py new file mode 100644 index 000000000..1a0876ca8 --- /dev/null +++ b/examples/tutorial/handson5/inference/batch.py @@ -0,0 +1,59 @@ +import torch +from typing import List, Deque, Tuple, Hashable, Any +from energonai import BatchManager, SubmitEntry, TaskEntry + + +class BatchManagerForGeneration(BatchManager): + def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.pad_token_id = pad_token_id + + def _left_padding(self, batch_inputs): + max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) + outputs = {'input_ids': [], 'attention_mask': []} + for inputs in batch_inputs: + input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + padding_len = max_len - len(input_ids) + input_ids = [self.pad_token_id] * padding_len + input_ids + attention_mask = [0] * padding_len + attention_mask + outputs['input_ids'].append(input_ids) + outputs['attention_mask'].append(attention_mask) + for k in outputs: + outputs[k] = torch.tensor(outputs[k]) + return outputs, max_len + + @staticmethod + def _make_batch_key(entry: SubmitEntry) -> tuple: + data = entry.data + return (data['top_k'], data['top_p'], data['temperature']) + + def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: + entry = q.popleft() + uids = [entry.uid] + batch = [entry.data] + while len(batch) < self.max_batch_size: + if len(q) == 0: + break + if self._make_batch_key(entry) != self._make_batch_key(q[0]): + break + if q[0].data['max_tokens'] > entry.data['max_tokens']: + break + e = q.popleft() + batch.append(e.data) + uids.append(e.uid) + inputs, max_len = self._left_padding(batch) + trunc_lens = [] + for data in batch: + trunc_lens.append(max_len + data['max_tokens']) + inputs['top_k'] = entry.data['top_k'] + inputs['top_p'] = entry.data['top_p'] + inputs['temperature'] = entry.data['temperature'] + inputs['max_tokens'] = max_len + entry.data['max_tokens'] + return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + + def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: + retval = [] + for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens): + retval.append((uid, output[:trunc_len])) + return retval diff --git a/examples/tutorial/handson5/inference/benchmark/locustfile.py b/examples/tutorial/handson5/inference/benchmark/locustfile.py new file mode 100644 index 000000000..4d829e5d8 --- /dev/null +++ b/examples/tutorial/handson5/inference/benchmark/locustfile.py @@ -0,0 +1,15 @@ +from locust import HttpUser, task +from json import JSONDecodeError + + +class GenerationUser(HttpUser): + @task + def generate(self): + prompt = 'Question: What is the longest river on the earth? 
Answer:' + for i in range(4, 9): + data = {'max_tokens': 2**i, 'prompt': prompt} + with self.client.post('/generation', json=data, catch_response=True) as response: + if response.status_code in (200, 406): + response.success() + else: + response.failure('Response wrong') diff --git a/examples/tutorial/handson5/inference/cache.py b/examples/tutorial/handson5/inference/cache.py new file mode 100644 index 000000000..30febc44f --- /dev/null +++ b/examples/tutorial/handson5/inference/cache.py @@ -0,0 +1,64 @@ +from collections import OrderedDict +from threading import Lock +from contextlib import contextmanager +from typing import List, Any, Hashable, Dict + + +class MissCacheError(Exception): + pass + + +class ListCache: + def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: + """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. + When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. + + Args: + cache_size (int): Max size for LRU cache. + list_size (int): Value list size. + fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to []. + """ + self.cache_size = cache_size + self.list_size = list_size + self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict() + self.fixed_cache: Dict[Hashable, List[Any]] = {} + for key in fixed_keys: + self.fixed_cache[key] = [] + self._lock = Lock() + + def get(self, key: Hashable) -> List[Any]: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) >= self.list_size: + return l + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) >= self.list_size: + return l + raise MissCacheError() + + def add(self, key: Hashable, value: Any) -> None: + with self.lock(): + if key in self.fixed_cache: + l = self.fixed_cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + elif key in self.cache: + self.cache.move_to_end(key) + l = self.cache[key] + if len(l) < self.list_size and value not in l: + l.append(value) + else: + if len(self.cache) >= self.cache_size: + self.cache.popitem(last=False) + self.cache[key] = [value] + + @contextmanager + def lock(self): + try: + self._lock.acquire() + yield + finally: + self._lock.release() diff --git a/examples/tutorial/handson5/inference/opt_fastapi.py b/examples/tutorial/handson5/inference/opt_fastapi.py new file mode 100644 index 000000000..cbfc2a22e --- /dev/null +++ b/examples/tutorial/handson5/inference/opt_fastapi.py @@ -0,0 +1,123 @@ +import argparse +import logging +import random +from typing import Optional + +import uvicorn +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Field +from transformers import GPT2Tokenizer + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = 
FastAPI() + + +@app.post('/generation') +async def generate(data: GenerationTaskReq, request: Request): + logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') + key = (data.prompt, data.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(data.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = data.max_tokens + inputs['top_k'] = data.top_k + inputs['top_p'] = data.top_p + inputs['temperature'] = data.temperature + try: + uid = id(data) + engine.submit(uid, inputs) + output = await engine.wait(uid) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + raise HTTPException(status_code=406, detail=e.args[0]) + + return {'text': output} + + +@app.on_event("shutdown") +async def shutdown(*_): + engine.shutdown() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + 
pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + config = uvicorn.Config(app, host=args.http_host, port=args.http_port) + server = uvicorn.Server(config=config) + server.run() diff --git a/examples/tutorial/handson5/inference/opt_server.py b/examples/tutorial/handson5/inference/opt_server.py new file mode 100644 index 000000000..8dab82622 --- /dev/null +++ b/examples/tutorial/handson5/inference/opt_server.py @@ -0,0 +1,122 @@ +import logging +import argparse +import random +from torch import Tensor +from pydantic import BaseModel, Field +from typing import Optional +from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B +from transformers import GPT2Tokenizer +from energonai import launch_engine, QueueFullError +from sanic import Sanic +from sanic.request import Request +from sanic.response import json +from sanic_ext import validate, openapi +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError + + +class GenerationTaskReq(BaseModel): + max_tokens: int = Field(gt=0, le=256, example=64) + prompt: str = Field( + min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + top_k: Optional[int] = Field(default=None, gt=0, example=50) + top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) + temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) + + +app = Sanic('opt') + + +@app.post('/generation') +@openapi.body(GenerationTaskReq) +@validate(json=GenerationTaskReq) +async def generate(request: Request, body: GenerationTaskReq): + logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}') + key = (body.prompt, body.max_tokens) + try: + if cache is None: + raise MissCacheError() + outputs = cache.get(key) + output = random.choice(outputs) + logger.info('Cache hit') + except MissCacheError: + inputs = tokenizer(body.prompt, truncation=True, max_length=512) + inputs['max_tokens'] = body.max_tokens + inputs['top_k'] = body.top_k + inputs['top_p'] = body.top_p + inputs['temperature'] = body.temperature + try: + uid = id(body) + engine.submit(uid, inputs) + output = await engine.wait(uid) + assert isinstance(output, Tensor) + output = tokenizer.decode(output, skip_special_tokens=True) + if cache is not None: + cache.add(key, output) + except QueueFullError as e: + return json({'detail': e.args[0]}, status=406) + + return json({'text': output}) + + +@app.after_server_stop +def shutdown(*_): + engine.shutdown() + + +def get_model_fn(model_name: str): + model_map = { + 'opt-125m': opt_125M, + 'opt-6.7b': opt_6B, + 'opt-30b': opt_30B, + 'opt-175b': opt_175B + } + return model_map[model_name] + + +def print_args(args: argparse.Namespace): + print('\n==> Args:') + for k, v in args.__dict__.items(): + print(f'{k} = {v}') + + +FIXED_CACHE_KEYS = [ + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) +] + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) + parser.add_argument('--tp', type=int, default=1) + parser.add_argument('--master_host', default='localhost') + parser.add_argument('--master_port', type=int, default=19990) + parser.add_argument('--rpc_port', type=int, default=19980) + parser.add_argument('--max_batch_size', type=int, default=8) + parser.add_argument('--pipe_size', type=int, default=1) + parser.add_argument('--queue_size', type=int, default=0) + parser.add_argument('--http_host', default='0.0.0.0') + parser.add_argument('--http_port', type=int, default=7070) + parser.add_argument('--checkpoint', default=None) + parser.add_argument('--cache_size', type=int, default=0) + parser.add_argument('--cache_list_size', type=int, default=1) + args = parser.parse_args() + print_args(args) + model_kwargs = {} + if args.checkpoint is not None: + model_kwargs['checkpoint'] = args.checkpoint + + logger = logging.getLogger(__name__) + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + if args.cache_size > 0: + cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) + else: + cache = None + engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, + pad_token_id=tokenizer.pad_token_id), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs) + app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/handson5/inference/requirements.txt b/examples/tutorial/handson5/inference/requirements.txt new file mode 100644 index 000000000..d0970d587 --- /dev/null +++ b/examples/tutorial/handson5/inference/requirements.txt @@ -0,0 +1,8 @@ +fastapi==0.85.1 +locust==2.11.0 +pydantic==1.10.2 +sanic==22.9.0 +sanic_ext==22.9.0 +torch>=1.10.0 +transformers==4.23.1 +uvicorn==0.19.0 diff --git a/examples/tutorial/handson5/inference/script/process-opt-175b/README.md b/examples/tutorial/handson5/inference/script/process-opt-175b/README.md new file mode 100644 index 000000000..bc3cba72d --- /dev/null +++ b/examples/tutorial/handson5/inference/script/process-opt-175b/README.md @@ -0,0 +1,46 @@ +# Process OPT-175B weights + +You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this. + +First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`. + +Then, `cd metaseq`. 
+ +To consolidate checkpoints to eliminate FSDP: + +```shell +bash metaseq/scripts/reshard_mp_launch_no_slurm.sh /checkpoint_last / 8 1 +``` + +You will get 8 files in ``, and you should have the following checksums: +``` +7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt +c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt +45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt +abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt +05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt +d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt +fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt +2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt +``` + +Copy `flat-meta.json` to ``. + +Then cd to this dir, and we unflatten parameters. + +```shell +bash unflat.sh / / +``` + +Finally, you will get 8 files in `` with following checksums: +``` +6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt +58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt +69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt +002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt +6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt +93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt +5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt +f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt +``` + diff --git a/examples/tutorial/handson5/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/handson5/inference/script/process-opt-175b/convert_ckpt.py new file mode 100644 index 000000000..a17ddd4fa --- /dev/null +++ b/examples/tutorial/handson5/inference/script/process-opt-175b/convert_ckpt.py @@ -0,0 +1,55 @@ +import argparse +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch + + +def load_json(path: str): + with open(path) as f: + return json.load(f) + + +def parse_shape_info(flat_dir: str): + data = load_json(os.path.join(flat_dir, 'shape.json')) + flat_info = defaultdict(lambda: defaultdict(list)) + for k, shape in data.items(): + matched = re.match(r'decoder.layers.\d+', k) + if matched is None: + flat_key = 'flat_param_0' + else: + flat_key = f'{matched[0]}.flat_param_0' + flat_info[flat_key]['names'].append(k) + flat_info[flat_key]['shapes'].append(shape) + flat_info[flat_key]['numels'].append(int(np.prod(shape))) + return flat_info + + +def convert(flat_dir: str, output_dir: str, part: int): + flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') + output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') + flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_sd = torch.load(flat_path) + print(f'Loaded flat state dict from {flat_path}') + output_sd = {} + for flat_key, param_meta in flat_meta.items(): + flat_param = flat_sd['model'][flat_key] + assert sum(param_meta['numels']) == flat_param.numel( + ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' + for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + output_sd[name] = param.view(shape) + + torch.save(output_sd, output_path) + print(f'Saved unflat state dict to {output_path}') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('flat_dir') + parser.add_argument('output_dir') + parser.add_argument('part', type=int) + args = parser.parse_args() + convert(args.flat_dir, args.output_dir, args.part) 
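To illustrate the core unflatten step performed by `convert()` above, here is a small self-contained sketch on toy data (the parameter names, shapes, and values are made up for illustration; the real metadata lives in `flat-meta.json`, shown next): the flat parameter is split by the recorded `numels`, and each piece is viewed back into its original shape.

```python
import torch

# Toy stand-in for one entry of flat-meta.json (names/shapes are illustrative only).
param_meta = {
    'names':  ['fc1.weight', 'fc1.bias'],
    'shapes': [[4, 3], [4]],
    'numels': [12, 4],
}
flat_param = torch.arange(16.0)  # stand-in for flat_sd['model'][flat_key]

assert sum(param_meta['numels']) == flat_param.numel()

output_sd = {}
for name, shape, piece in zip(param_meta['names'], param_meta['shapes'],
                              flat_param.split(param_meta['numels'])):
    output_sd[name] = piece.view(shape)   # restore the original parameter shape

print({k: tuple(v.shape) for k, v in output_sd.items()})
# -> {'fc1.weight': (4, 3), 'fc1.bias': (4,)}
```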
diff --git a/examples/tutorial/handson5/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/handson5/inference/script/process-opt-175b/flat-meta.json new file mode 100644 index 000000000..59d285565 --- /dev/null +++ b/examples/tutorial/handson5/inference/script/process-opt-175b/flat-meta.json @@ -0,0 +1 @@ +{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], 
[12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", 
"decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 
6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", 
"decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", 
"decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", 
"decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", 
"decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", 
"decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", 
"decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", 
"decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", 
"decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", 
"decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 
12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file diff --git a/examples/tutorial/handson5/inference/script/process-opt-175b/unflat.sh b/examples/tutorial/handson5/inference/script/process-opt-175b/unflat.sh new file mode 100644 index 000000000..cc5c190e2 --- /dev/null +++ b/examples/tutorial/handson5/inference/script/process-opt-175b/unflat.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +for i in $(seq 0 7); do + python convert_ckpt.py $1 $2 ${i} & +done + +wait $(jobs -p) diff --git a/examples/tutorial/handson5/inference/script/processing_ckpt_66b.py b/examples/tutorial/handson5/inference/script/processing_ckpt_66b.py new file mode 100644 index 000000000..0494647d7 --- /dev/null +++ b/examples/tutorial/handson5/inference/script/processing_ckpt_66b.py @@ -0,0 +1,55 @@ +import os +import torch +from multiprocessing import Pool + +# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main +# you can use whether wget or git lfs + +path = "/path/to/your/ckpt" +new_path = "/path/to/the/processed/ckpt/" + +assert os.path.isdir(path) +files = [] +for filename in os.listdir(path): + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + files.append(filepath) + +with Pool(14) as pool: + ckpts = pool.map(torch.load, files) + +restored = {} +for ckpt in ckpts: + for k,v in ckpt.items(): + if(k[0] == 'm'): + k = k[6:] + if(k == "lm_head.weight"): + k = "head.dense.weight" + if(k == "decoder.final_layer_norm.weight"): + k = "decoder.layer_norm.weight" + if(k == "decoder.final_layer_norm.bias"): + k = "decoder.layer_norm.bias" + restored[k] = v +restored["decoder.version"] = "0.0" + + +split_num = len(restored.keys()) // 60 +count = 0 +file_count = 1 +tmp = {} +for k,v in restored.items(): + print(k) + tmp[k] = v + count = count + 1 + if(count == split_num): + filename = str(file_count) + "-restored.pt" + torch.save(tmp, os.path.join(new_path, filename)) + file_count = file_count + 1 + count = 0 + tmp = {} + +filename = 
str(file_count) + "-restored.pt" +torch.save(tmp, os.path.join(new_path, filename)) + + + diff --git a/examples/tutorial/handson5/opt/README.md b/examples/tutorial/handson5/opt/README.md new file mode 100644 index 000000000..4ed0bf3ab --- /dev/null +++ b/examples/tutorial/handson5/opt/README.md @@ -0,0 +1,53 @@ + +# Train OPT model with Colossal-AI + +## OPT +Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. + +We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before +the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. + +## Quick Start +You can launch training by using the following bash script + +```bash +bash ./run_clm.sh +``` + +- batch-size-per-gpu: number of samples fed to each GPU, default is 16 +- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) +- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request +the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +- gpu-num: the number of GPUs to use, default is 1. + +## Remarkable Performance +On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. +Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. + +
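The `mem-cap` option from the Quick Start above bounds how much GPU memory each training process may allocate, which is how the benchmarks constrain single-GPU runs. The sketch below only illustrates that idea with plain PyTorch's per-process memory fraction (the helper name is made up); the training script itself uses Colossal-AI's own memory utilities for the same purpose.

```python
import torch


def cap_gpu_memory(size_in_gb: int, device: int = 0) -> None:
    """Illustrative only: cap this process at `size_in_gb` GB of GPU memory (0 = no cap)."""
    if size_in_gb <= 0:
        return  # matches the run_clm.sh default of 0, i.e. no limit
    total_bytes = torch.cuda.get_device_properties(device).total_memory
    requested_bytes = size_in_gb * (1024 ** 3)
    if requested_bytes < total_bytes:
        torch.cuda.set_per_process_memory_fraction(requested_bytes / total_bytes, device)
```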
+ +Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! + +More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), +and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/tutorial/handson5/opt/benchmark.sh b/examples/tutorial/handson5/opt/benchmark.sh new file mode 100644 index 000000000..f02f7629a --- /dev/null +++ b/examples/tutorial/handson5/opt/benchmark.sh @@ -0,0 +1,21 @@ +export BS=16 +export MEMCAP=0 +export MODEL="6.7b" +export GPUNUM=1 + +for MODEL in "6.7b" "13b" "1.3b" +do +for GPUNUM in 8 1 +do +for BS in 16 24 32 8 +do +for MEMCAP in 0 40 +do +pkill -9 torchrun +pkill -9 python + +bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +done +done +done +done diff --git a/examples/tutorial/handson5/opt/colossalai_zero.py b/examples/tutorial/handson5/opt/colossalai_zero.py new file mode 100644 index 000000000..833745f3e --- /dev/null +++ b/examples/tutorial/handson5/opt/colossalai_zero.py @@ -0,0 +1,6 @@ +from colossalai.zero.shard_utils import TensorShardStrategy + +zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), + tensor_placement_policy="auto", + reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/tutorial/handson5/opt/context.py b/examples/tutorial/handson5/opt/context.py new file mode 100644 index 000000000..95f0abf1d --- /dev/null +++ b/examples/tutorial/handson5/opt/context.py @@ -0,0 +1,32 @@ +import torch.distributed as dist + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + + +class barrier_context(): + """ + This context manager is used to allow one process to execute while blocking all + other processes in the same process group. This is often useful when downloading is required + as we only want to download in one process to prevent file corruption. 
+ Args: + executor_rank (int): the process rank to execute without blocking, all other processes will be blocked + parallel_mode (ParallelMode): the parallel mode corresponding to a process group + Usage: + with barrier_context(): + dataset = CIFAR10(root='./data', download=True) + """ + + def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): + # the class name is lowercase by convention + current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) + self.should_block = current_rank != executor_rank + self.group = gpc.get_group(parallel_mode=parallel_mode) + + def __enter__(self): + if self.should_block: + dist.barrier(group=self.group) + + def __exit__(self, exc_type, exc_value, exc_traceback): + if not self.should_block: + dist.barrier(group=self.group) diff --git a/examples/tutorial/handson5/opt/requirements.txt b/examples/tutorial/handson5/opt/requirements.txt new file mode 100644 index 000000000..c34df7992 --- /dev/null +++ b/examples/tutorial/handson5/opt/requirements.txt @@ -0,0 +1,6 @@ +colossalai +torch >= 1.8.1 +datasets >= 1.8.0 +sentencepiece != 0.1.92 +protobuf +accelerate == 0.13.2 diff --git a/examples/tutorial/handson5/opt/run_clm.py b/examples/tutorial/handson5/opt/run_clm.py new file mode 100755 index 000000000..00e05459a --- /dev/null +++ b/examples/tutorial/handson5/opt/run_clm.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
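# Illustrative aside on the barrier_context helper defined in context.py above: it lets one
# rank execute a block (for example a dataset download) while every other rank in the group
# waits, then releases them on exit. Written with explicit barriers instead of the context
# manager, the same pattern looks roughly like the sketch below. It is not called anywhere in
# this script and assumes the default torch.distributed process group is already initialized.
import torch.distributed as dist


def _run_rank_zero_first(fn):
    if dist.get_rank() != 0:
        dist.barrier()    # every non-zero rank waits here until rank 0 has finished
    result = fn()         # rank 0 does the real work first; later ranks typically hit a cache
    if dist.get_rank() == 0:
        dist.barrier()    # rank 0 reaches the barrier last and releases the waiting ranks
    return result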
+ +import math +import os +import time +from itertools import chain + +import datasets +import torch +import torch.distributed as dist +from accelerate.utils import set_seed +from context import barrier_context +from datasets import load_dataset +from packaging import version +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +import colossalai +import transformers +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoTokenizer, + GPT2Tokenizer, + OPTForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils.versions import require_version + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def get_time_stamp(): + torch.cuda.synchronize() + return time.time() + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=None, + help="A csv or a json file containing the training data.") + parser.add_argument("--validation_file", + type=str, + default=None, + help="A csv or a json file containing the validation data.") + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + 
parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument("--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)."), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets") + parser.add_argument("--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed."), + ) + + parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") + parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." 
+ if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def colo_memory_cap(size_in_GB): + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) + if size_in_GB * (1024**3) < cuda_capacity: + colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) + print("Using {} GB of GPU memory".format(size_in_GB)) + + +def main(): + args = parse_args() + disable_existing_loggers() + colossalai.launch_from_torch(config=dict()) + logger = get_dist_logger() + is_main_process = dist.get_rank() == 0 + + if is_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if args.mem_cap > 0: + colo_memory_cap(args.mem_cap) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") + + # Handle the repository creation + with barrier_context(): + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + logger.info("Start preparing dataset", ranks=[0]) + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. 
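    # For example, with the default validation_split_percentage=5 the script loads
    # split="train[:5%]" as the validation set and split="train[5%:]" as the training set.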
+ if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + logger.info("Dataset is prepared", ranks=[0]) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + logger.info("Model config has been created", ranks=[0]) + + if args.model_name_or_path == 'facebook/opt-13b': + tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) + else: + print(f'load model from {args.model_name_or_path}') + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) + + if args.init_in_cpu: + init_dev = torch.device('cpu') + else: + init_dev = get_current_device() + + # build model + if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + # currently, there has a bug in pretrained opt-13b + # we can not import it until huggingface fix it + logger.info("Train a new model from scratch", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM(config) + else: + logger.info("Finetune a pre-trained model", ranks=[0]) + with ColoInitContext(device=init_dev): + model = OPTForCausalLM.from_pretrained(args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False) + + # enable graident checkpointing + model.gradient_checkpointing_enable() + + PLACEMENT_POLICY = 'auto' + cai_version = colossalai.__version__ + logger.info(f'using Colossal-AI version {cai_version}') + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + + # Preprocessing the datasets. + # First we tokenize all the texts. 
+ column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i:i + block_size] for i in range(0, total_length, block_size) + ] for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + # for index in random.sample(range(len(train_dataset)), 3): + # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = get_dataloader(train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size) + logger.info("Dataloaders have been created", ranks=[0]) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. 
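    # Biases and LayerNorm weights are conventionally kept out of weight decay; every other
    # parameter is decayed with args.weight_decay, as the two groups below implement.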
+ no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) + + logger.info("***** Running training *****", ranks=[0]) + logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) + logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) + logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) + completed_steps = 0 + starting_epoch = 0 + global_step = 0 + + for epoch in range(starting_epoch, args.num_train_epochs): + + if completed_steps >= args.max_train_steps: + break + + model.train() + for step, batch in enumerate(train_dataloader): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + loss = outputs['loss'] + optimizer.backward(loss) + + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + global_step += 1 + logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + batch = {k: v.cuda() for k, v in batch.items()} + outputs = model(**batch) + + loss = outputs['loss'].unsqueeze(0) + losses.append(loss) + + losses = torch.cat(losses) + losses = losses[:len(eval_dataset)] + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) + + if args.output_dir is not None: + model_state = model.state_dict() + if is_main_process: + torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + dist.barrier() + # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + # model.load_state_dict(load_state, strict=False) + + logger.info("Training finished", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorial/handson5/opt/run_clm.sh b/examples/tutorial/handson5/opt/run_clm.sh new file mode 100644 index 000000000..858d3325a --- /dev/null +++ b/examples/tutorial/handson5/opt/run_clm.sh @@ -0,0 +1,22 @@ +set -x +export BS=${1:-16} +export MEMCAP=${2:-0} +export MODEL=${3:-"125m"} +export GPUNUM=${4:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/tutorial/handson5/zero/README.md b/examples/tutorial/handson5/zero/README.md new file mode 100644 index 000000000..1af7f7cdc --- /dev/null +++ b/examples/tutorial/handson5/zero/README.md @@ -0,0 +1,16 @@ +## Overview +This example shows how to use ColossalAI to run huggingface GPT training with Gemini and ZeRO DDP. + +## GPT +We use the huggingface transformers GPT2 model. The input data is randonly generated. + +## Our Modifications +We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. 
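At its core, `train_gpt_demo.py` (shown further below) materializes the Hugging Face GPT-2 model under `ColoInitContext` and then hands it to Gemini + ZeRO DDP together with a `HybridAdam` optimizer wrapped in `ZeroOptimizer`. The following is only a condensed sketch of that wiring on a recent Colossal-AI release, with the model config and hyper-parameters reduced to placeholders; note that the backward pass then goes through `optimizer.backward(loss)` rather than `loss.backward()`.

```python
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

# build the model as ColoTensors so Gemini can manage parameter placement
with ColoInitContext(device=get_current_device()):
    model = GPT2LMHeadModel(GPT2Config())    # placeholder config

# Gemini + ZeRO DDP: parameters are chunked and moved according to the placement policy
model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)

# hybrid CPU/GPU Adam wrapped by ZeroOptimizer, which also handles loss scaling
optimizer = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=2**5)
```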
+ +## Quick Start +You can launch training by using the following bash script + +```bash +pip install -r requirements.txt +bash run.sh +``` diff --git a/examples/tutorial/handson5/zero/requirements.txt b/examples/tutorial/handson5/zero/requirements.txt new file mode 100644 index 000000000..208a31ebb --- /dev/null +++ b/examples/tutorial/handson5/zero/requirements.txt @@ -0,0 +1,3 @@ +colossalai >= 0.1.10 +torch >= 1.8.1 +transformers >= 4.231 diff --git a/examples/tutorial/handson5/zero/run.sh b/examples/tutorial/handson5/zero/run.sh new file mode 100644 index 000000000..1ff2a4eed --- /dev/null +++ b/examples/tutorial/handson5/zero/run.sh @@ -0,0 +1 @@ +env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log diff --git a/examples/tutorial/handson5/zero/train_gpt_demo.py b/examples/tutorial/handson5/zero/train_gpt_demo.py new file mode 100644 index 000000000..cdf7c41b2 --- /dev/null +++ b/examples/tutorial/handson5/zero/train_gpt_demo.py @@ -0,0 +1,241 @@ +from functools import partial +from time import time + +import psutil +import torch +import torch.nn as nn +from packaging import version + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from transformers import GPT2Config, GPT2LMHeadModel + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini.", + ) + args = parser.parse_args() + return args + + +## Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + if param.process_group.tp_world_size() == 1: + param.set_process_group(pg) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = 
labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +## Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def gpt2_medium(checkpoint=False): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +def gpt2_xl(checkpoint=True): + return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint) + + +def gpt2_10b(checkpoint=True): + return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +# Tensor Parallel +def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): + """tensor_parallelize + Sharding the Model Parameters. + + Args: + model (torch.nn.Module): a torch module to be sharded + """ + for mn, module in model.named_modules(): + for pn, param in module.named_parameters(recurse=False): + # set process group for all parameters + param.set_process_group(pg) + + if 'mlp.c_fc' in mn: + if 'weight' in pn or 'bias' in pn: + split_param_col_tp1d(param, pg) # colmn slice + # keep the shape of the output from c_fc + param.compute_spec.set_output_replicate(False) + elif 'mlp.c_proj' in mn: + if 'weight' in pn: + split_param_row_tp1d(param, pg) # row slice + elif 'wte' in mn or 'wpe' in mn: + split_param_col_tp1d(param, pg) # colmn slice + elif 'c_attn' in mn or 'c_proj' in mn: + split_param_col_tp1d(param, pg) # colmn slice + + +# Gemini + ZeRO DDP +def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): + cai_version = colossalai.__version__ + if version.parse(cai_version) > version.parse("0.1.10"): + from colossalai.nn.parallel import GeminiDDP + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=32) + elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): + from colossalai.gemini import ChunkManager, GeminiManager + chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) + gemini_manager = GeminiManager(placememt_policy, chunk_manager) + chunk_manager = ChunkManager(chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(placememt_policy)) + model = ZeroDDP(model, gemini_manager) + else: + raise NotImplemented(f"CAI version {cai_version} is not supported") + return model + + +def main(): + args = parse_args() + + BATCH_SIZE = 8 + SEQ_LEN = 1024 + VOCAB_SIZE = 50257 + NUM_STEPS = 10 + + disable_existing_loggers() + colossalai.launch_from_torch(config={}) + + pg = ProcessGroup(tp_degree=args.tp_degree) + + logger = get_dist_logger() + logger.info(get_mem_info(), ranks=[0]) + + # build GPT model + with ColoInitContext(device=get_current_device()): + model = gpt2_medium(checkpoint=True) + + numel = 
sum([p.numel() for p in model.parameters()]) + logger.info(f'Model numel: {numel}', ranks=[0]) + get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN) + + # Tensor Parallelism (TP) + tensor_parallelize(model, pg) + # Gemini + ZeRO DP, Note it must be used after TP + model = gemini_zero_dpp(model, pg, args.placement) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + + # build criterion + criterion = GPTLMLoss() + + # build optimizer + optimizer = HybridAdam(model.parameters(), lr=1e-3) + optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) + logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + + torch.cuda.synchronize() + model.train() + for n in range(NUM_STEPS): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = model(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0]) + optimizer.backward(loss) + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0]) + optimizer.step() + logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/diffusion/LICENSE b/examples/tutorial/handson6/LICENSE similarity index 100% rename from examples/tutorial/diffusion/LICENSE rename to examples/tutorial/handson6/LICENSE diff --git a/examples/tutorial/diffusion/README.md b/examples/tutorial/handson6/README.md similarity index 99% rename from examples/tutorial/diffusion/README.md rename to examples/tutorial/handson6/README.md index 38878ab71..a5256600d 100644 --- a/examples/tutorial/diffusion/README.md +++ b/examples/tutorial/handson6/README.md @@ -1,4 +1,5 @@ -# Stable Diffusion with Colossal-AI +# Handson 6: Acceleration of Stable Diffusion + *[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).* diff --git a/examples/tutorial/diffusion/configs/train_colossalai.yaml b/examples/tutorial/handson6/configs/train_colossalai.yaml similarity index 100% rename from examples/tutorial/diffusion/configs/train_colossalai.yaml rename to examples/tutorial/handson6/configs/train_colossalai.yaml diff --git a/examples/tutorial/diffusion/configs/train_ddp.yaml b/examples/tutorial/handson6/configs/train_ddp.yaml similarity index 100% rename from examples/tutorial/diffusion/configs/train_ddp.yaml rename to examples/tutorial/handson6/configs/train_ddp.yaml diff --git a/examples/tutorial/diffusion/configs/train_pokemon.yaml b/examples/tutorial/handson6/configs/train_pokemon.yaml similarity index 100% rename from examples/tutorial/diffusion/configs/train_pokemon.yaml rename to examples/tutorial/handson6/configs/train_pokemon.yaml diff --git a/examples/tutorial/diffusion/environment.yaml b/examples/tutorial/handson6/environment.yaml similarity index 100% rename from examples/tutorial/diffusion/environment.yaml rename to examples/tutorial/handson6/environment.yaml diff --git 
a/examples/tutorial/diffusion/ldm/data/__init__.py b/examples/tutorial/handson6/ldm/data/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/data/__init__.py rename to examples/tutorial/handson6/ldm/data/__init__.py diff --git a/examples/tutorial/diffusion/ldm/data/base.py b/examples/tutorial/handson6/ldm/data/base.py similarity index 100% rename from examples/tutorial/diffusion/ldm/data/base.py rename to examples/tutorial/handson6/ldm/data/base.py diff --git a/examples/tutorial/diffusion/ldm/data/imagenet.py b/examples/tutorial/handson6/ldm/data/imagenet.py similarity index 100% rename from examples/tutorial/diffusion/ldm/data/imagenet.py rename to examples/tutorial/handson6/ldm/data/imagenet.py diff --git a/examples/tutorial/diffusion/ldm/data/lsun.py b/examples/tutorial/handson6/ldm/data/lsun.py similarity index 100% rename from examples/tutorial/diffusion/ldm/data/lsun.py rename to examples/tutorial/handson6/ldm/data/lsun.py diff --git a/examples/tutorial/diffusion/ldm/lr_scheduler.py b/examples/tutorial/handson6/ldm/lr_scheduler.py similarity index 100% rename from examples/tutorial/diffusion/ldm/lr_scheduler.py rename to examples/tutorial/handson6/ldm/lr_scheduler.py diff --git a/examples/tutorial/diffusion/ldm/models/autoencoder.py b/examples/tutorial/handson6/ldm/models/autoencoder.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/autoencoder.py rename to examples/tutorial/handson6/ldm/models/autoencoder.py diff --git a/examples/tutorial/diffusion/ldm/models/diffusion/__init__.py b/examples/tutorial/handson6/ldm/models/diffusion/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/diffusion/__init__.py rename to examples/tutorial/handson6/ldm/models/diffusion/__init__.py diff --git a/examples/tutorial/diffusion/ldm/models/diffusion/classifier.py b/examples/tutorial/handson6/ldm/models/diffusion/classifier.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/diffusion/classifier.py rename to examples/tutorial/handson6/ldm/models/diffusion/classifier.py diff --git a/examples/tutorial/diffusion/ldm/models/diffusion/ddim.py b/examples/tutorial/handson6/ldm/models/diffusion/ddim.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/diffusion/ddim.py rename to examples/tutorial/handson6/ldm/models/diffusion/ddim.py diff --git a/examples/tutorial/diffusion/ldm/models/diffusion/ddpm.py b/examples/tutorial/handson6/ldm/models/diffusion/ddpm.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/diffusion/ddpm.py rename to examples/tutorial/handson6/ldm/models/diffusion/ddpm.py diff --git a/examples/tutorial/diffusion/ldm/models/diffusion/plms.py b/examples/tutorial/handson6/ldm/models/diffusion/plms.py similarity index 100% rename from examples/tutorial/diffusion/ldm/models/diffusion/plms.py rename to examples/tutorial/handson6/ldm/models/diffusion/plms.py diff --git a/examples/tutorial/diffusion/ldm/modules/attention.py b/examples/tutorial/handson6/ldm/modules/attention.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/attention.py rename to examples/tutorial/handson6/ldm/modules/attention.py diff --git a/examples/tutorial/diffusion/ldm/modules/diffusionmodules/__init__.py b/examples/tutorial/handson6/ldm/modules/diffusionmodules/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/diffusionmodules/__init__.py rename to 
examples/tutorial/handson6/ldm/modules/diffusionmodules/__init__.py diff --git a/examples/tutorial/diffusion/ldm/modules/diffusionmodules/model.py b/examples/tutorial/handson6/ldm/modules/diffusionmodules/model.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/diffusionmodules/model.py rename to examples/tutorial/handson6/ldm/modules/diffusionmodules/model.py diff --git a/examples/tutorial/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/tutorial/handson6/ldm/modules/diffusionmodules/openaimodel.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/diffusionmodules/openaimodel.py rename to examples/tutorial/handson6/ldm/modules/diffusionmodules/openaimodel.py diff --git a/examples/tutorial/diffusion/ldm/modules/diffusionmodules/util.py b/examples/tutorial/handson6/ldm/modules/diffusionmodules/util.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/diffusionmodules/util.py rename to examples/tutorial/handson6/ldm/modules/diffusionmodules/util.py diff --git a/examples/tutorial/diffusion/ldm/modules/distributions/__init__.py b/examples/tutorial/handson6/ldm/modules/distributions/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/distributions/__init__.py rename to examples/tutorial/handson6/ldm/modules/distributions/__init__.py diff --git a/examples/tutorial/diffusion/ldm/modules/distributions/distributions.py b/examples/tutorial/handson6/ldm/modules/distributions/distributions.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/distributions/distributions.py rename to examples/tutorial/handson6/ldm/modules/distributions/distributions.py diff --git a/examples/tutorial/diffusion/ldm/modules/ema.py b/examples/tutorial/handson6/ldm/modules/ema.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/ema.py rename to examples/tutorial/handson6/ldm/modules/ema.py diff --git a/examples/tutorial/diffusion/ldm/modules/encoders/__init__.py b/examples/tutorial/handson6/ldm/modules/encoders/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/encoders/__init__.py rename to examples/tutorial/handson6/ldm/modules/encoders/__init__.py diff --git a/examples/tutorial/diffusion/ldm/modules/encoders/modules.py b/examples/tutorial/handson6/ldm/modules/encoders/modules.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/encoders/modules.py rename to examples/tutorial/handson6/ldm/modules/encoders/modules.py diff --git a/examples/tutorial/diffusion/ldm/modules/flash_attention.py b/examples/tutorial/handson6/ldm/modules/flash_attention.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/flash_attention.py rename to examples/tutorial/handson6/ldm/modules/flash_attention.py diff --git a/examples/tutorial/diffusion/ldm/modules/image_degradation/__init__.py b/examples/tutorial/handson6/ldm/modules/image_degradation/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/image_degradation/__init__.py rename to examples/tutorial/handson6/ldm/modules/image_degradation/__init__.py diff --git a/examples/tutorial/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/tutorial/handson6/ldm/modules/image_degradation/bsrgan.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/image_degradation/bsrgan.py rename to examples/tutorial/handson6/ldm/modules/image_degradation/bsrgan.py diff --git 
a/examples/tutorial/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/tutorial/handson6/ldm/modules/image_degradation/bsrgan_light.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/image_degradation/bsrgan_light.py rename to examples/tutorial/handson6/ldm/modules/image_degradation/bsrgan_light.py diff --git a/examples/tutorial/diffusion/ldm/modules/image_degradation/utils/test.png b/examples/tutorial/handson6/ldm/modules/image_degradation/utils/test.png similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/image_degradation/utils/test.png rename to examples/tutorial/handson6/ldm/modules/image_degradation/utils/test.png diff --git a/examples/tutorial/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/tutorial/handson6/ldm/modules/image_degradation/utils_image.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/image_degradation/utils_image.py rename to examples/tutorial/handson6/ldm/modules/image_degradation/utils_image.py diff --git a/examples/tutorial/diffusion/ldm/modules/losses/__init__.py b/examples/tutorial/handson6/ldm/modules/losses/__init__.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/losses/__init__.py rename to examples/tutorial/handson6/ldm/modules/losses/__init__.py diff --git a/examples/tutorial/diffusion/ldm/modules/losses/contperceptual.py b/examples/tutorial/handson6/ldm/modules/losses/contperceptual.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/losses/contperceptual.py rename to examples/tutorial/handson6/ldm/modules/losses/contperceptual.py diff --git a/examples/tutorial/diffusion/ldm/modules/losses/vqperceptual.py b/examples/tutorial/handson6/ldm/modules/losses/vqperceptual.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/losses/vqperceptual.py rename to examples/tutorial/handson6/ldm/modules/losses/vqperceptual.py diff --git a/examples/tutorial/diffusion/ldm/modules/x_transformer.py b/examples/tutorial/handson6/ldm/modules/x_transformer.py similarity index 100% rename from examples/tutorial/diffusion/ldm/modules/x_transformer.py rename to examples/tutorial/handson6/ldm/modules/x_transformer.py diff --git a/examples/tutorial/diffusion/ldm/util.py b/examples/tutorial/handson6/ldm/util.py similarity index 100% rename from examples/tutorial/diffusion/ldm/util.py rename to examples/tutorial/handson6/ldm/util.py diff --git a/examples/tutorial/diffusion/main.py b/examples/tutorial/handson6/main.py similarity index 100% rename from examples/tutorial/diffusion/main.py rename to examples/tutorial/handson6/main.py diff --git a/examples/tutorial/diffusion/requirements.txt b/examples/tutorial/handson6/requirements.txt similarity index 100% rename from examples/tutorial/diffusion/requirements.txt rename to examples/tutorial/handson6/requirements.txt diff --git a/examples/tutorial/diffusion/scripts/download_first_stages.sh b/examples/tutorial/handson6/scripts/download_first_stages.sh similarity index 100% rename from examples/tutorial/diffusion/scripts/download_first_stages.sh rename to examples/tutorial/handson6/scripts/download_first_stages.sh diff --git a/examples/tutorial/diffusion/scripts/download_models.sh b/examples/tutorial/handson6/scripts/download_models.sh similarity index 100% rename from examples/tutorial/diffusion/scripts/download_models.sh rename to examples/tutorial/handson6/scripts/download_models.sh diff --git a/examples/tutorial/diffusion/scripts/img2img.py 
b/examples/tutorial/handson6/scripts/img2img.py similarity index 100% rename from examples/tutorial/diffusion/scripts/img2img.py rename to examples/tutorial/handson6/scripts/img2img.py diff --git a/examples/tutorial/diffusion/scripts/inpaint.py b/examples/tutorial/handson6/scripts/inpaint.py similarity index 100% rename from examples/tutorial/diffusion/scripts/inpaint.py rename to examples/tutorial/handson6/scripts/inpaint.py diff --git a/examples/tutorial/diffusion/scripts/knn2img.py b/examples/tutorial/handson6/scripts/knn2img.py similarity index 100% rename from examples/tutorial/diffusion/scripts/knn2img.py rename to examples/tutorial/handson6/scripts/knn2img.py diff --git a/examples/tutorial/diffusion/scripts/sample_diffusion.py b/examples/tutorial/handson6/scripts/sample_diffusion.py similarity index 100% rename from examples/tutorial/diffusion/scripts/sample_diffusion.py rename to examples/tutorial/handson6/scripts/sample_diffusion.py diff --git a/examples/tutorial/diffusion/scripts/tests/test_checkpoint.py b/examples/tutorial/handson6/scripts/tests/test_checkpoint.py similarity index 100% rename from examples/tutorial/diffusion/scripts/tests/test_checkpoint.py rename to examples/tutorial/handson6/scripts/tests/test_checkpoint.py diff --git a/examples/tutorial/diffusion/scripts/tests/test_watermark.py b/examples/tutorial/handson6/scripts/tests/test_watermark.py similarity index 100% rename from examples/tutorial/diffusion/scripts/tests/test_watermark.py rename to examples/tutorial/handson6/scripts/tests/test_watermark.py diff --git a/examples/tutorial/diffusion/scripts/train_searcher.py b/examples/tutorial/handson6/scripts/train_searcher.py similarity index 100% rename from examples/tutorial/diffusion/scripts/train_searcher.py rename to examples/tutorial/handson6/scripts/train_searcher.py diff --git a/examples/tutorial/diffusion/scripts/txt2img.py b/examples/tutorial/handson6/scripts/txt2img.py similarity index 100% rename from examples/tutorial/diffusion/scripts/txt2img.py rename to examples/tutorial/handson6/scripts/txt2img.py diff --git a/examples/tutorial/diffusion/setup.py b/examples/tutorial/handson6/setup.py similarity index 100% rename from examples/tutorial/diffusion/setup.py rename to examples/tutorial/handson6/setup.py diff --git a/examples/tutorial/diffusion/train.sh b/examples/tutorial/handson6/train.sh similarity index 100% rename from examples/tutorial/diffusion/train.sh rename to examples/tutorial/handson6/train.sh