ColossalAI/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
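
"""Checkpoint IO tests for the Gemini booster plugin.

Covers saving and reloading model/optimizer state under tensor parallelism and ZeRO
data parallelism, with synchronous and async IO, in safetensors and plain ``.pt`` formats.
"""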


import os
import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
    check_state_dict_equal,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
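
# With the static placement policy, `shard_param_frac` is the fraction of parameters
# sharded across GPUs and `offload_optim_frac` the fraction of optimizer states kept on
# CPU; the configs below exercise half-sharded params and a ZeRO-2-style optimizer offload.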
MODEL_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.5},
]

OPTIM_PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
]


@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(
    placement_config,
    model_name,
    use_safetensors: bool,
    tp_size: int,
    zero_size: int,
):
    """Save a Gemini-boosted BERT model in HuggingFace format and check it can be reloaded with ``from_pretrained``."""
    from transformers import BertForSequenceClassification

    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()

    enable_flash_attention = tp_size > 1
    enable_fused_normalization = tp_size > 1
    enable_jit_fused = tp_size > 1

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")
        bert_model.config.save_pretrained(save_directory=pretrained_path)
        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
        plugin = GeminiPlugin(
            **placement_config,
            tp_size=tp_size,
            enable_flash_attention=enable_flash_attention,
            enable_fused_normalization=enable_fused_normalization,
            enable_jit_fused=enable_jit_fused,
            extra_dp_size=extra_dp_size,
        )
        booster = Booster(plugin=plugin)
        bert_model, _, _, _, _ = booster.boost(bert_model)
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
        # Positional args map to shard=True, gather_dtensor=True, prefix="", size_per_shard=model_size / 3.
        booster.save_model(
            bert_model,
            pretrained_path,
            True,
            True,
            "",
            (model_size / 3),
            use_safetensors=use_safetensors,
        )
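        # Make sure any asynchronous checkpoint writes have finished on every rank before reading the files back.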
        booster.checkpoint_io._sync_d2h()
        booster.checkpoint_io._sync_io()
        dist.barrier()

        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [False])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@parameterize(
    "use_async",
    [
        True,
    ],
)
def exam_state_dict(
    placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
    """Round-trip model and optimizer checkpoints through the booster and check the reloaded states match."""
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()

    enable_flash_attention = tp_size > 1
    enable_fused_normalization = tp_size > 1
    enable_jit_fused = tp_size > 1

    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
    plugin = GeminiPlugin(
        **placement_config,
        precision="fp16",
        initial_scale=(2**14),
        tp_size=tp_size,
        extra_dp_size=extra_dp_size,
        enable_flash_attention=enable_flash_attention,
        enable_fused_normalization=enable_fused_normalization,
        enable_jit_fused=enable_jit_fused,
    )
    booster = Booster(plugin=plugin)

    model = model_fn()
    new_model = model_fn()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.step()

    # Change the learning rate so loading the optimizer checkpoint can be verified to restore param groups.
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    # Disabled local-debugging path that saved checkpoints to ./checkpoints instead of a temp dir.
    """output_dir = "./checkpoints"
    import os
    os.makedirs(output_dir, exist_ok=True)
    model_ckpt_path = f"{output_dir}/model"
    optimizer_ckpt_path = f"{output_dir}/optimizer"
    if not shard:
        model_ckpt_path = f"{model_ckpt_path}.safetensors"
    print("model_ckpt_path", model_ckpt_path)
    booster.save_model(
        model,
        model_ckpt_path,
        shard=shard,
        size_per_shard=size_per_shard,
        use_async=use_async,
    )
    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)"""
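
    # Save the model/optimizer to a shared temp dir, then reload them into fresh instances below.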
    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"
        # Async saving writes a safetensors file; the synchronous path uses a plain .pt file here.
        if use_async:
            model_ckpt_path = f"{model_ckpt_path}.safetensors"
        else:
            model_ckpt_path = f"{model_ckpt_path}.pt"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
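        # Wait for any pending async checkpoint writes to land on disk before loading the files back.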
        booster.checkpoint_io._sync_d2h()
        booster.checkpoint_io._sync_io()
        dist.barrier()

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(
            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
        )

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
        for group in new_optimizer.param_groups:
            assert group["lr"] == 0.1

        # Check the new model/optimizer can successfully run.
        data = data_gen_fn()
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }
        output = new_model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])
        booster.backward(loss, new_optimizer)
        new_optimizer.step()
        booster.save_model(new_model, model_ckpt_path, shard=shard)
        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


def exam_lazy_from_pretrained():
    """Check that a lazily initialized LLaMA model saves the same weights as its eagerly loaded counterpart."""
    llama_path = os.environ["LLAMA_PATH"]
    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)
    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
    with LazyInitContext():
        model = LlamaForCausalLM.from_pretrained(llama_path)
    model, *_ = booster.boost(model)
    with shared_tempdir() as tempdir:
        save_path = os.path.join(tempdir, "model.pt")
        booster.save_model(model, save_path, shard=False)
        dist.barrier()
        state_dict = torch.load(save_path, map_location="cpu")
        check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_state_dict()
    exam_state_dict_with_origin()
    # exam_lazy_from_pretrained()
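    # NOTE: exam_lazy_from_pretrained needs a local LLaMA checkpoint supplied via the LLAMA_PATH env var.

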
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_gemini_ckpIO():
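    # Four processes are enough to cover tp_size=2 together with zero_size=2.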
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_gemini_ckpIO()