[misc] Accelerate CI for zero and dist optim (#5758)

* remove fp16 from lamb

* remove d2h copy in checking states

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Committed by Edenzzzz via GitHub, 6 months ago
parent 50b4c8e8cf
commit 79f7a7b211

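The "remove d2h copy in checking states" bullet refers to the state-dict comparisons in the tests below: call sites either drop the old positional flag or spell it out as a keyword (`ignore_device=False` / `ignore_dtype=True`), and the comparison no longer copies every tensor device-to-host before checking it. The snippet below is a minimal sketch of the on-device comparison idea, not ColossalAI's actual `check_state_dict_equal`; the function name and argument handling are illustrative.

```python
from collections import OrderedDict

import torch
from torch.testing import assert_close


def compare_state_dicts_on_device(d1: OrderedDict, d2: OrderedDict, ignore_dtype: bool = False):
    """Sketch: compare two state dicts without a blanket device-to-host copy.

    Tensors are compared on d1's device, so at most the d2 tensor is moved
    (host-to-device if needed) instead of both sides being copied to the CPU.
    """
    assert set(d1.keys()) == set(d2.keys()), "state dict keys differ"
    for k, v1 in d1.items():
        v2 = d2[k]
        if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
            v2 = v2.to(v1.device)      # no-op when devices already match
            if ignore_dtype:
                v2 = v2.to(v1.dtype)   # tolerate mixed-precision checkpoints
            assert_close(v1, v2)
        else:
            assert v1 == v2, f"mismatch at key {k}: {v1} vs {v2}"
```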
@ -22,8 +22,6 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
b,
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}",
)

@ -72,7 +72,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
@clear_cache_before_run()
@ -130,13 +130,11 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(
optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
for group in new_optimizer.param_groups:
assert group["lr"] == 0.1
@ -169,7 +167,7 @@ def exam_lazy_from_pretrained():
booster.save_model(model, save_path, shard=False)
dist.barrier()
state_dict = torch.load(save_path, map_location="cpu")
check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)
def run_dist(rank, world_size, port):

@ -62,12 +62,12 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
check_state_dict_equal(
model.state_dict(only_rank_0=False, prefix="module.module."),
new_model.state_dict(),
False,
ignore_device=False,
ignore_dtype=True,
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), ignore_device=False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()
@ -128,7 +128,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
check_state_dict_equal(
new_model.state_dict(only_rank_0=False, prefix="module.module."),
model.state_dict(),
False,
ignore_device=False,
ignore_dtype=True,
)
@ -145,7 +145,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
k in old_group and k in new_group
), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False)
check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], ignore_device=False)
# Check the new model/optimizer can successfully run.
data = data_gen_fn()

@ -94,9 +94,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
dist.barrier()
# Check whether the loaded model & optimizer works smoothly.

@ -55,7 +55,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
check_state_dict_equal(model.state_dict(), new_model.state_dict())
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
@ -70,7 +70,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
torch.cuda.empty_cache()
@ -110,7 +110,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
check_state_dict_equal(model.state_dict(), new_model.state_dict())
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@ -126,7 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
except Exception as e:
# return repr(e)

@ -61,9 +61,9 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
if plugin_type == "gemini":
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False))
else:
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
dist.barrier()

@ -52,12 +52,12 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
check_state_dict_equal(model.state_dict(), new_model.state_dict())
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())
def run_dist(rank, world_size, port):

@ -3,7 +3,6 @@ import torch.distributed as dist
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
@ -119,8 +118,12 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**15 # avoid overflow
target_models = [
"transformers_bert",
]
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in target_models:
check_bert_fwd_bwd(
model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
)
@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
use_zero = sharded_optimizer.use_zero
tp_optim_state = tp_state[key]
p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
state = p_state[key]
dp_size, tp_size = (
sharded_optimizer.dp_size,
sharded_optimizer.tp_size,
@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
if shard_spec.sharding_sequence[0] == "R":
if use_zero:
# sq_row needs gather along the dp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col doesn't need gather along the dp group
if key == "exp_avg_sq_col":
pass
else:
pass
if key == "exp_avg_sq_row":
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# gather from tp group
# sq_row doesn't need gather along the tp group
if key == "exp_avg_sq_row":
pass
# sq_col needs gather along the dp group
# sq_col needs gather along the tp group
if key == "exp_avg_sq_col":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
# row parallel
if shard_spec.sharding_sequence[-1] == "R":
if use_zero:
elif shard_spec.sharding_sequence[-1] == "R":
# TODO: this case may cause shape mismatch @duanjunwen
if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
# sq_row needs gather along the dp group
if key == "exp_avg_sq_row":
if p_state[key].shape[0] // tp_size % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
# sq_col doesn't need gather along the dp group
if key == "exp_avg_sq_col":
pass
else:
pass
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# gather from tp group
# sq_row needs gather along the tp group
if key == "exp_avg_sq_row":
tp_optim_state = _gather(
input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
)
tp_optim_state.shape
state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
# sq_col doesn't need gather along the dp group
if key == "exp_avg_sq_col":
pass
else:
return
else:
if use_zero:
# sq_row needs gather along the dp group
if key == "exp_avg_sq_row":
# row residual; no gather
if p_state[key].shape[0] % dp_size != 0:
if state.shape[0] % dp_size != 0:
pass
else:
tp_optim_state = _gather(
input_=tp_optim_state,
dim=-1,
process_group=sharded_optimizer.dp_group,
)
tp_optim_state.shape
state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
# sq_col doesn't need gather along the dp group
if key == "exp_avg_sq_col":
tp_optim_state = tp_optim_state.div_(dp_size)
# needs a division by dp_size
else:
pass
# Solves a newly observed issue: mismatched dtypes.
# So far this only happens in the H100 environment;
# it seems torch.set_default_dtype(torch.bfloat16) does not affect booster.precision,
# or assert_close was updated to also check dtype.
if p_state[key].dtype != tp_optim_state.dtype:
tp_optim_state = tp_optim_state.type(p_state[key].dtype)
try:
assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
except:
pass
if state.dtype != tp_optim_state.dtype:
tp_optim_state = tp_optim_state.type(state.dtype)
# TODO: some sharding checks are currently buggy, but the state values should match
# @duanjunwen
if state.shape != tp_optim_state.shape:
return
assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)
def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):

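The rewrite of `check_dist_optim_state` above drops the `_gather(...)` calls over the dp/tp process groups: instead of all-gathering each sharded optimizer state and comparing full tensors, the unsharded reference is chunked down to the slice the current rank is expected to hold, so the check involves no collective communication and no extra host copies. A minimal sketch of that pattern, with illustrative names:

```python
import torch
import torch.distributed as dist
from torch.testing import assert_close


def check_sharded_state(full_state: torch.Tensor,
                        local_shard: torch.Tensor,
                        group: dist.ProcessGroup,
                        dim: int = -1) -> None:
    """Sketch: compare a rank's shard against the matching slice of the full tensor.

    Rather than all-gathering `local_shard` from every rank, slice `full_state`
    locally so the comparison needs no communication.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    expected = full_state.chunk(world_size, dim=dim)[rank]
    assert_close(expected, local_shard, atol=5e-4, rtol=1.6e-2)
```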
@ -7,14 +7,11 @@ from torch import nn
from torch.testing import assert_close
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.adafactor import Adafactor
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import (
distribute_tensor,
@ -59,7 +56,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
rtol = 4e-3
atol = 4e-3
# return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
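`correctness_verify` asserts with `assert_close` instead of returning a boolean (the commented-out `torch.all(...)` line is dropped above), using tolerances chosen per dtype. A sketch of that pattern: the 4e-3 pair appears in the diff, but how it maps onto dtypes here, and the fp32 entry, are illustrative.

```python
import torch
from torch.testing import assert_close

# Per-dtype tolerances: the 4e-3 values appear in the test above; their dtype
# assignment and the fp32 entry are illustrative defaults.
_TOLERANCES = {
    torch.float32: dict(rtol=5e-4, atol=5e-4),
    torch.float16: dict(rtol=4e-3, atol=4e-3),
    torch.bfloat16: dict(rtol=4e-3, atol=4e-3),
}


def correctness_verify_sketch(t1: torch.Tensor, t2: torch.Tensor, dtype: torch.dtype) -> None:
    # Raising on mismatch surfaces the failing values in pytest output,
    # whereas a returned bool could silently be ignored by callers.
    assert_close(t1, t2, **_TOLERANCES[dtype])
```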
@ -194,7 +190,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# Col Parallel
# ==============================
weight_col_shard = shard_colwise(weight.clone(), tp_group)
weight_col_shard_layout = get_layout(weight_col_shard) # Layout info weight_col_shard_layout.global_shape
weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard) # Shard spec
weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True))
bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
@ -203,17 +198,12 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
# Row Parallel
# ==============================
weight_row_shard = shard_rowwise(weight.clone(), tp_group)
weight_row_shard_layout = get_layout(weight_row_shard) # Layout info weight_row_shard_layout.global_shape
weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard) # Shard spec
weight_row_shard_flatten = nn.Parameter(
weight_row_shard.clone().flatten().requires_grad_(True)
) # flatten input(not dtensor) to optimizer
bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
# base_param_group = setup_param_groups([weight, bias])
# cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten])
# rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten])
# ==============================
# Init Optimizer
# ==============================
@ -267,19 +257,11 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
bias_row_flatten.grad = bias.grad.clone().flatten()
rp_dist_optim.step()
# gather result
weight_col_gather = _gather(
input_=weight_col_shard_flatten.data.view(-1, H // tp_size),
dim=-1,
process_group=tp_group,
) # gather
weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view(
-1, W
) # gather
weight_row_chunk = weight.t().reshape(-1, W).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
weight_col_chunk = weight.reshape(-1, H).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
# verify
correctness_verify(weight.data, weight_col_gather.data, dtype)
correctness_verify(weight.data, weight_row_gather.data, dtype)
correctness_verify(weight_col_chunk, weight_col_shard_flatten, dtype)
correctness_verify(weight_row_chunk, weight_row_shard_flatten, dtype)
print(f"Base Test Passed")
@ -307,7 +289,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# ==============================
# Optimizer Init
@ -378,141 +360,19 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
if len(shard_spec.sharding_sequence) >= 2:
# Col Parallel
if shard_spec.sharding_sequence[0] == "R":
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]
# ROW Parallel
if shard_spec.sharding_sequence[-1] == "R":
tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather
p = p.chunk(tp_size, dim=0)[dist.get_rank(tp_group)]
else:
# TP bias
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
else:
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Zero Test Passed")
@parameterize("dtype", [torch.float16])
@parameterize("tp_zero_size", [(1, 4)])
def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
tp_size, zero_size = tp_zero_size
use_zero = True if zero_size > 1 else False
local_rank = dist.get_rank()
p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]
correctness_verify(p, tp_p, dtype)
clear_layout_converter()
proc_mesh = ProcessGroupMesh(tp_size, zero_size)
tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)
torch.set_default_dtype(dtype)
set_seed(42)
# ==============================
# Model Init
# ==============================
base_model = MlpModel().to(local_rank)
# tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
tp_model = copy.deepcopy(base_model).to(local_rank)
base_param_group = setup_param_groups(base_model)
tp_param_group = setup_param_groups(tp_model)
tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
# ==============================
# Optimizer Init
# ==============================
base_optim = Adafactor(base_param_group)
dist_optim = DistributedAdaFactor(tp_param_group)
# Setup distributed optimizer
if zero_size > 1:
base_optim = LowLevelZeroOptimizer(
base_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
dist_optim = LowLevelZeroOptimizer(
dist_optim,
overlap_communication=True,
initial_scale=128,
partition_grad=True,
dp_process_group=dp_group,
verbose=True,
)
shard_to_param = dist_optim._param_store.master_to_working_param # {id(): param tensor} but flattened
dist_optim.optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
else:
shard_to_param = set_master_param_to_shard_param(tp_param_group)
dist_optim.setup_distributed(
tp_group=tp_group,
dp_group=dp_group,
shard_to_working_param=shard_to_param,
use_zero=use_zero,
)
# ==============================
# Booster Init
# ==============================
plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)
criterion = lambda x: x.mean()
tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)
# ==============================
# Correctness Verify
# ==============================
x = torch.randn(HEIGHT, WIDTH, device=local_rank)
out = base_model(x)
out_tp = tp_model(x)
if zero_size > 1:
dist_optim.backward(out_tp.sum())
base_optim.backward(out.sum())
else:
out_tp.sum().backward()
out.sum().backward()
base_optim.step()
dist_optim.step()
base_optim.zero_grad()
dist_optim.zero_grad()
for p, tp_p in zip(base_param_group, tp_param_group):
param_is_distributed = is_distributed_tensor(tp_p)
if param_is_distributed:
shard_spec = get_sharding_spec(tp_p)
if len(shard_spec.sharding_sequence) >= 2:
# Col Parallel
if shard_spec.sharding_sequence[0] == "R":
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
# ROW Parallel
if shard_spec.sharding_sequence[-1] == "R":
tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group) # gather
else:
# TP bias
tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group) # gather
else:
# No TP bias
pass
correctness_verify(p.data, tp_p.data, dtype)
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Booster Test Passed")
print(f"Zero Test Passed")
@parameterize(
@ -532,14 +392,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
@ -627,14 +479,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
test_config["initial_scale"] = 2**16 # avoid overflow
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
@ -673,6 +517,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
# check optim states
check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Bert Model Zoo Test Passed")
@ -681,11 +526,10 @@ def exam_bert_test_on_hybrid_plugin(test_config):
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_bert_test_on_lowlevelzero_plugin()
exam_bert_test_on_hybrid_plugin()
exam_dist_adafactor_base()
exam_dist_adafactor_zero()
exam_dist_adafactor_booster()
exam_bert_test_on_lowlevelzero_plugin()
exam_bert_test_on_hybrid_plugin()
@pytest.mark.dist

@ -287,15 +287,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
# test_config["initial_scale"] = 1
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
"simple_mlp",
]
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
@ -389,14 +380,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
test_config["initial_scale"] = 2**16 # avoid overflow
model_list = [
"transformers_bert",
"transformers_bert_for_pretraining",
"transformers_bert_lm_head_model",
"transformers_bert_for_masked_lm",
"transformers_bert_for_sequence_classification",
"transformers_bert_for_token_classification",
"transformers_bert_for_next_sentence",
"transformers_bert_for_mcq",
"transformers_bert_for_question_answering",
]
# pass "transformers_bert",

@ -18,7 +18,6 @@ from tests.test_optimizer._utils import check_optim_states, run_bert_test
_ALLOWED_P_G_TYPES = [
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
]
@ -264,7 +263,6 @@ def run_dist_lamb_fwd_bwd(
torch_optim.step()
optim.step()
dist.barrier()
torch_optim.zero_grad()
optim.zero_grad()
try:

@ -1,126 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
{"placement_policy": "static", "shard_param_frac": 1.0}, # zero3
{"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half
{"placement_policy": "auto"},
]
def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
if not model.chunk_manager.reuse_fp16_chunk:
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
for p0, p1 in zip(model.parameters(), torch_model.parameters()):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
@parameterize("max_prefetch", [0, 4])
@parameterize("enable_async_reduce", [False, True])
def exam_gpt_fwd_bwd(
placement_config,
keep_gather,
model_name: str,
use_grad_checkpoint: bool = False,
master_weights: bool = True,
max_prefetch: int = 0,
enable_async_reduce=True,
):
init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42)
model = model_builder()
set_seed(42)
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
torch_model.gradient_checkpointing_enable()
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gather
model = GeminiDDP(
model,
config_dict,
init_device,
pin_memory=True,
**placement_config,
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
rank = dist.get_rank()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[rank])
set_seed(rank)
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
torch_optim.zero_grad()
zero_optim.zero_grad()
# setting the same random seed plays the same role as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
assert_close(torch_loss.float(), loss.float())
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_gpt(1)