mirror of https://github.com/hpcaitech/ColossalAI
[misc] Accelerate CI for zero and dist optim (#5758)
* remove fp16 from lamb
* remove d2h copy in checking states

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

parent 50b4c8e8cf
commit 79f7a7b211
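The second bullet, removing device-to-host copies while checking states, boils down to comparing saved and reloaded state dicts on whatever device the tensors already live on, instead of staging everything on the CPU first. The sketch below only illustrates that idea under assumed names; it is not the repository's actual check_state_dict_equal helper, whose call sites change throughout the diff that follows.

import torch
from torch.testing import assert_close


def compare_state_dicts_on_device(d1, d2, rtol=1e-3, atol=1e-3):
    # Hypothetical helper: compare two (possibly nested) state dicts tensor-by-tensor
    # without an unconditional GPU->CPU copy.
    assert d1.keys() == d2.keys(), f"key mismatch: {list(d1.keys())} vs {list(d2.keys())}"
    for k, v1 in d1.items():
        v2 = d2[k]
        if isinstance(v1, dict):
            compare_state_dicts_on_device(v1, v2, rtol, atol)
        elif isinstance(v1, torch.Tensor):
            # Move a tensor only when the two sides live on different devices.
            if v2.device != v1.device:
                v2 = v2.to(v1.device)
            assert_close(v1, v2.to(v1.dtype), rtol=rtol, atol=atol)
        else:
            assert v1 == v2, f"value mismatch for key {k}: {v1} vs {v2}"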
@@ -22,8 +22,6 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
        b,
        rtol=rtol,
        atol=atol,
        msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
            dtype: {a.dtype} vs {b.dtype}",
    )

@@ -72,7 +72,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
    dist.barrier()

    new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-   check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
+   check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())


@clear_cache_before_run()
@@ -130,13 +130,11 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha

    booster.load_model(new_model, model_ckpt_path)
    check_state_dict_equal(
-       model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
+       model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
    )

    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-   check_state_dict_equal(
-       optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
-   )
+   check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
    for group in new_optimizer.param_groups:
        assert group["lr"] == 0.1

@@ -169,7 +167,7 @@ def exam_lazy_from_pretrained():
    booster.save_model(model, save_path, shard=False)
    dist.barrier()
    state_dict = torch.load(save_path, map_location="cpu")
-   check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
+   check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)


def run_dist(rank, world_size, port):
@@ -62,12 +62,12 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
    check_state_dict_equal(
        model.state_dict(only_rank_0=False, prefix="module.module."),
        new_model.state_dict(),
-       False,
+       ignore_device=False,
        ignore_dtype=True,
    )

    new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-   check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+   check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), ignore_device=False)

    # Check the new model/optimizer can successfully run.
    data = data_gen_fn()
@@ -128,7 +128,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
    check_state_dict_equal(
        new_model.state_dict(only_rank_0=False, prefix="module.module."),
        model.state_dict(),
-       False,
+       ignore_device=False,
        ignore_dtype=True,
    )

@@ -145,7 +145,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
            k in old_group and k in new_group
        ), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
        assert old_group[k] == new_group[k]
-   check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False)
+   check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], ignore_device=False)

    # Check the new model/optimizer can successfully run.
    data = data_gen_fn()
@@ -94,9 +94,9 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    booster.load_model(new_model, model_ckpt_path)
-   check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+   check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-   check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
+   check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
    dist.barrier()

    # Check whether the loaded model & optimizer works smoothly.
@@ -55,7 +55,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
    new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)

    booster.load_model(new_model, model_ckpt_path)
-   check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+   check_state_dict_equal(model.state_dict(), new_model.state_dict())
    # check master weight
    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
    working_param_id_set = set(id(p) for p in new_model.parameters())
@@ -70,7 +70,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
    )

    booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-   check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+   check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
    torch.cuda.empty_cache()

@@ -110,7 +110,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
        new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
        new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
-       check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+       check_state_dict_equal(model.state_dict(), new_model.state_dict())

        # check master weight
        assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@@ -126,7 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
        )

        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-       check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+       check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

    except Exception as e:
        # return repr(e)
@@ -61,9 +61,9 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

    if plugin_type == "gemini":
-       check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
+       check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False))
    else:
-       check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+       check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
    dist.barrier()

@@ -52,12 +52,12 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
        )

        booster.load_model(new_model, model_ckpt_path)
-       check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+       check_state_dict_equal(model.state_dict(), new_model.state_dict())

        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-       check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+       check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
        booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-       check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+       check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())


def run_dist(rank, world_size, port):
@@ -3,7 +3,6 @@ import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
@@ -119,11 +118,15 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
    test_config["use_lazy_init"] = False
    test_config["pp_size"] = 1  # Do NOT test Pipeline Parallel
    test_config["initial_scale"] = 2**15  # avoid overflow
+   target_models = [
+       "transformers_bert",
+   ]

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-       check_bert_fwd_bwd(
-           model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
-       )
+       if name in target_models:
+           check_bert_fwd_bwd(
+               model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
+           )

    clear_layout_converter()
    Randomizer.reset_index()
@@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
    shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
    use_zero = sharded_optimizer.use_zero
    tp_optim_state = tp_state[key]
-   p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
+   state = p_state[key]

    dp_size, tp_size = (
        sharded_optimizer.dp_size,
        sharded_optimizer.tp_size,
@@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
    if shard_spec.sharding_sequence[0] == "R":
        if use_zero:
            # sq_row need gather alone dp group
            if key == "exp_avg_sq_row":
                tp_optim_state = _gather(
                    input_=tp_optim_state,
                    dim=-1,
                    process_group=sharded_optimizer.dp_group,
                )
                tp_optim_state.shape
            # sq_col don't need gather alone dp group
            if key == "exp_avg_sq_col":
                pass
        else:
            pass
            if key == "exp_avg_sq_row":
                state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

        # gather from tp group
        # sq_row don need gather alone tp group
        if key == "exp_avg_sq_row":
            pass
        # sq_col need gather alone dp group
        # sq_col need gather alone tp group
        if key == "exp_avg_sq_col":
            tp_optim_state = _gather(
                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
            )
            tp_optim_state.shape

            state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
    # row parallel
    if shard_spec.sharding_sequence[-1] == "R":
        if use_zero:
    elif shard_spec.sharding_sequence[-1] == "R":
        # TODO: this case may cause shape mismatch @duanjunwen
        if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
            # sq_row need gather alone dp group
            if key == "exp_avg_sq_row":
                if p_state[key].shape[0] // tp_size % dp_size != 0:
                    pass
                else:
                    tp_optim_state = _gather(
                        input_=tp_optim_state,
                        dim=-1,
                        process_group=sharded_optimizer.dp_group,
                    )
                    tp_optim_state.shape
            # sq_col don't need gather alone dp group
            if key == "exp_avg_sq_col":
                pass
        else:
            pass

            state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

        # gather from tp group
        # sq_row need gather alone tp group
        if key == "exp_avg_sq_row":
            tp_optim_state = _gather(
                input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
            )
            tp_optim_state.shape
            state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
        # sq_col don't need gather alone dp group
        if key == "exp_avg_sq_col":
            pass
    else:
        return
    else:
        if use_zero:
            # sq_row need gather alone dp group
            if key == "exp_avg_sq_row":
                # row residule; no gather
                if p_state[key].shape[0] % dp_size != 0:
                if state.shape[0] % dp_size != 0:
                    pass
                else:
                    tp_optim_state = _gather(
                        input_=tp_optim_state,
                        dim=-1,
                        process_group=sharded_optimizer.dp_group,
                    )
                    tp_optim_state.shape
                    state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
            # sq_col don't need gather alone dp group
            if key == "exp_avg_sq_col":
                tp_optim_state = tp_optim_state.div_(dp_size)
                # need a div;
        else:
            pass
    # Sovled a New issus: different dtype;
    # So far, only happen in H100 env;
    # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision;
    # Or assert_close just update to check dtype;
    if p_state[key].dtype != tp_optim_state.dtype:
        tp_optim_state = tp_optim_state.type(p_state[key].dtype)
    try:
        assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
    except:
        pass

    if state.dtype != tp_optim_state.dtype:
        tp_optim_state = tp_optim_state.type(state.dtype)
    # TODO: some sharding checks are currently buggy, but the state values should match
    # @duanjunwen
    if state.shape != tp_optim_state.shape:
        return
    assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)


def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):
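The pattern appearing in the hunk above trades all-gathers of the sharded optimizer state for a local slice of the unsharded reference, state.chunk(world_size, dim=...)[dist.get_rank(group)], so the comparison stays small and stays on the device. A minimal, self-contained sketch of that check follows; the function name and arguments are hypothetical, it assumes an already-initialized process group, and the tolerances simply mirror the ones used above.

import torch
import torch.distributed as dist
from torch.testing import assert_close


def check_sharded_state_against_reference(full_state, local_shard, group, dim=-1):
    # Hypothetical helper: verify a sharded optimizer state without gathering it.
    # Rather than all-gathering local_shard across `group` (communication plus a
    # potentially large temporary tensor), slice out the portion of the unsharded
    # reference that this rank owns and compare the two small tensors in place.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    reference_chunk = full_state.chunk(world_size, dim=dim)[rank]
    assert_close(reference_chunk, local_shard, atol=5e-4, rtol=1.6e-2)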
@@ -7,14 +7,11 @@ from torch import nn
from torch.testing import assert_close

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer.adafactor import Adafactor
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor import (
    distribute_tensor,
@@ -59,7 +56,6 @@ def correctness_verify(tensor1: torch.Tensor, tensor2: torch.Tensor, dtype: torc
        rtol = 4e-3
        atol = 4e-3

-   # return torch.all(tensor1.isclose(tensor2, rtol=rtol, atol=atol))
    assert_close(tensor1, tensor2, rtol=rtol, atol=atol)

@@ -194,7 +190,6 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    # Col Parallel
    # ==============================
    weight_col_shard = shard_colwise(weight.clone(), tp_group)
-   weight_col_shard_layout = get_layout(weight_col_shard)  # Layout info weight_col_shard_layout.global_shape
    weight_col_shard_shard_spec = get_sharding_spec(weight_col_shard)  # Shard spec
    weight_col_shard_flatten = nn.Parameter(weight_col_shard.clone().flatten().requires_grad_(True))
    bias_col_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))
@@ -203,17 +198,12 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    # Row Parallel
    # ==============================
    weight_row_shard = shard_rowwise(weight.clone(), tp_group)
-   weight_row_shard_layout = get_layout(weight_row_shard)  # Layout info weight_row_shard_layout.global_shape
    weight_row_shard_shard_spec = get_sharding_spec(weight_row_shard)  # Shard spec
    weight_row_shard_flatten = nn.Parameter(
        weight_row_shard.clone().flatten().requires_grad_(True)
    )  # flatten input(not dtensor) to optimizer
    bias_row_flatten = nn.Parameter(bias.clone().flatten().requires_grad_(True))

-   # base_param_group = setup_param_groups([weight, bias])
-   # cp_param_group = setup_param_groups([weight_col_shard_flatten, bias_col_flatten])
-   # rp_param_group = setup_param_groups([weight_row_shard_flatten, bias_row_flatten])

    # ==============================
    # Init Optimizer
    # ==============================
@@ -267,19 +257,11 @@ def exam_dist_adafactor_base(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    bias_row_flatten.grad = bias.grad.clone().flatten()
    rp_dist_optim.step()

-   # gather result
-   weight_col_gather = _gather(
-       input_=weight_col_shard_flatten.data.view(-1, H // tp_size),
-       dim=-1,
-       process_group=tp_group,
-   )  # gather
-   weight_row_gather = _gather(input_=weight_row_shard_flatten.data, dim=-1, process_group=tp_group).view(
-       -1, W
-   )  # gather

+   weight_row_chunk = weight.t().reshape(-1, W).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
+   weight_col_chunk = weight.reshape(-1, H).chunk(tp_size, dim=-1)[dist.get_rank(tp_group)].flatten()
    # verify
-   correctness_verify(weight.data, weight_col_gather.data, dtype)
-   correctness_verify(weight.data, weight_row_gather.data, dtype)
+   correctness_verify(weight_col_chunk, weight_col_shard_flatten, dtype)
+   correctness_verify(weight_row_chunk, weight_row_shard_flatten, dtype)

    print(f"Base Test Passed")
@@ -307,7 +289,7 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):

    base_param_group = setup_param_groups(base_model)
    tp_param_group = setup_param_groups(tp_model)
-   tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)
+   # tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)

    # ==============================
    # Optimizer Init
@@ -378,143 +360,21 @@ def exam_dist_adafactor_zero(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
            if len(shard_spec.sharding_sequence) >= 2:
                # Col Parallel
                if shard_spec.sharding_sequence[0] == "R":
                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
                    p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]
                # ROW Parallel
                if shard_spec.sharding_sequence[-1] == "R":
                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
                    p = p.chunk(tp_size, dim=0)[dist.get_rank(tp_group)]
            else:
                # TP bias
                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
        else:
            # No TP bias
            pass
        correctness_verify(p.data, tp_p.data, dtype)
                p = p.chunk(tp_size, dim=-1)[dist.get_rank(tp_group)]

        correctness_verify(p, tp_p, dtype)
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Zero Test Passed")


@parameterize("dtype", [torch.float16])
@parameterize("tp_zero_size", [(1, 4)])
def exam_dist_adafactor_booster(dtype: torch.dtype, tp_zero_size: tuple[int, int]):
    tp_size, zero_size = tp_zero_size
    use_zero = True if zero_size > 1 else False
    local_rank = dist.get_rank()

    clear_layout_converter()

    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
    tp_group, dp_group = proc_mesh.get_group_along_axis(0), proc_mesh.get_group_along_axis(1)

    torch.set_default_dtype(dtype)
    set_seed(42)

    # ==============================
    # Model Init
    # ==============================
    base_model = MlpModel().to(local_rank)
    # tp_model = TPModel(copy.deepcopy(base_model.linear1), copy.deepcopy(base_model.linear2), tp_group).to(local_rank)
    tp_model = copy.deepcopy(base_model).to(local_rank)

    base_param_group = setup_param_groups(base_model)
    tp_param_group = setup_param_groups(tp_model)
    tp_param_group_, tp_shard_spec, tp_param_shape = setup_flatten_param_groups_sharding_spec_shape(tp_model)

    # ==============================
    # Optimizer Init
    # ==============================
    base_optim = Adafactor(base_param_group)
    dist_optim = DistributedAdaFactor(tp_param_group)

    # Setup distributed optimizer
    if zero_size > 1:
        base_optim = LowLevelZeroOptimizer(
            base_optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )

        dist_optim = LowLevelZeroOptimizer(
            dist_optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )
        shard_to_param = dist_optim._param_store.master_to_working_param  # {id(): param tensor} but flattened
        dist_optim.optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )
    else:
        shard_to_param = set_master_param_to_shard_param(tp_param_group)
        dist_optim.setup_distributed(
            tp_group=tp_group,
            dp_group=dp_group,
            shard_to_working_param=shard_to_param,
            use_zero=use_zero,
        )

    # ==============================
    # Booster Init
    # ==============================
    plugin = LowLevelZeroPlugin()
    booster = Booster(plugin=plugin)
    criterion = lambda x: x.mean()

    tp_model, dist_optim, criterion, _, _ = booster.boost(tp_model, dist_optim, criterion)

    # ==============================
    # Correctness Verify
    # ==============================
    x = torch.randn(HEIGHT, WIDTH, device=local_rank)

    out = base_model(x)
    out_tp = tp_model(x)

    if zero_size > 1:
        dist_optim.backward(out_tp.sum())
        base_optim.backward(out.sum())
    else:
        out_tp.sum().backward()
        out.sum().backward()

    base_optim.step()
    dist_optim.step()

    base_optim.zero_grad()
    dist_optim.zero_grad()

    for p, tp_p in zip(base_param_group, tp_param_group):
        param_is_distributed = is_distributed_tensor(tp_p)
        if param_is_distributed:
            shard_spec = get_sharding_spec(tp_p)
            if len(shard_spec.sharding_sequence) >= 2:
                # Col Parallel
                if shard_spec.sharding_sequence[0] == "R":
                    tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
                # ROW Parallel
                if shard_spec.sharding_sequence[-1] == "R":
                    tp_p = _gather(input_=tp_p, dim=0, process_group=tp_group)  # gather
            else:
                # TP bias
                tp_p = _gather(input_=tp_p, dim=-1, process_group=tp_group)  # gather
        else:
            # No TP bias
            pass
        correctness_verify(p.data, tp_p.data, dtype)
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Booster Test Passed")


@parameterize(
    "test_config",
    [
@@ -532,14 +392,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
    model_list = [
        "transformers_bert",
-       "transformers_bert_for_pretraining",
-       "transformers_bert_lm_head_model",
-       "transformers_bert_for_masked_lm",
-       "transformers_bert_for_sequence_classification",
-       "transformers_bert_for_token_classification",
-       "transformers_bert_for_next_sentence",
-       "transformers_bert_for_mcq",
-       "transformers_bert_for_question_answering",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -627,14 +479,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
-       "transformers_bert_for_pretraining",
-       "transformers_bert_lm_head_model",
-       "transformers_bert_for_masked_lm",
-       "transformers_bert_for_sequence_classification",
-       "transformers_bert_for_token_classification",
-       "transformers_bert_for_next_sentence",
-       "transformers_bert_for_mcq",
-       "transformers_bert_for_question_answering",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -673,6 +517,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
        # check optim states
        check_dist_optim_state(org_optimizer, sharded_optimizer.optim)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
    print(f"Bert Model Zoo Test Passed")
@@ -681,11 +526,10 @@ def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-   exam_bert_test_on_lowlevelzero_plugin()
-   exam_bert_test_on_hybrid_plugin()
    exam_dist_adafactor_base()
    exam_dist_adafactor_zero()
-   exam_dist_adafactor_booster()
+   exam_bert_test_on_lowlevelzero_plugin()
+   exam_bert_test_on_hybrid_plugin()


@pytest.mark.dist
@@ -287,15 +287,6 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
    # test_config["initial_scale"] = 1
    model_list = [
        "transformers_bert",
-       "transformers_bert_for_pretraining",
-       "transformers_bert_lm_head_model",
-       "transformers_bert_for_masked_lm",
-       "transformers_bert_for_sequence_classification",
-       "transformers_bert_for_token_classification",
-       "transformers_bert_for_next_sentence",
-       "transformers_bert_for_mcq",
-       "transformers_bert_for_question_answering",
-       "simple_mlp",
    ]
    clear_layout_converter()
    torch.set_default_dtype(torch.bfloat16)
@@ -389,14 +380,6 @@ def exam_bert_test_on_hybrid_plugin(test_config):
    test_config["initial_scale"] = 2**16  # avoid overflow
    model_list = [
        "transformers_bert",
-       "transformers_bert_for_pretraining",
-       "transformers_bert_lm_head_model",
-       "transformers_bert_for_masked_lm",
-       "transformers_bert_for_sequence_classification",
-       "transformers_bert_for_token_classification",
-       "transformers_bert_for_next_sentence",
-       "transformers_bert_for_mcq",
-       "transformers_bert_for_question_answering",
    ]

    # pass "transformers_bert",
@@ -18,7 +18,6 @@ from tests.test_optimizer._utils import check_optim_states, run_bert_test

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
-   (torch.float, torch.half),  # fp16 amp
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

@@ -264,7 +263,6 @@ def run_dist_lamb_fwd_bwd(

    torch_optim.step()
    optim.step()
-   dist.barrier()
    torch_optim.zero_grad()
    optim.zero_grad()
    try:
@@ -1,126 +0,0 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd

PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {"placement_policy": "auto"},
]


def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
    chunk_manager = model.chunk_manager
    param_list = [p for p in model.parameters()]
    chunk_list = chunk_manager.get_chunks(param_list)
    if not model.chunk_manager.reuse_fp16_chunk:
        chunk_list = [chunk.grad_chunk for chunk in chunk_list]
    for chunk in chunk_list:
        chunk_manager.access_chunk(chunk)

    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
@parameterize("max_prefetch", [0, 4])
@parameterize("enable_async_reduce", [False, True])
def exam_gpt_fwd_bwd(
    placement_config,
    keep_gather,
    model_name: str,
    use_grad_checkpoint: bool = False,
    master_weights: bool = True,
    max_prefetch: int = 0,
    enable_async_reduce=True,
):
    init_device = get_accelerator().get_current_device()
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )

    set_seed(42)
    model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    if use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        torch_model.gradient_checkpointing_enable()

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = keep_gather
    model = GeminiDDP(
        model,
        config_dict,
        init_device,
        pin_memory=True,
        **placement_config,
        master_weights=master_weights,
        max_prefetch=max_prefetch,
        enable_async_reduce=enable_async_reduce,
    )
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

    rank = dist.get_rank()
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[rank])

    set_seed(rank)

    data = data_gen_fn()
    data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}

    torch_optim.zero_grad()
    zero_optim.zero_grad()

    # set random seed is same as torch_model.eval()
    set_seed(42)
    torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
    set_seed(42)
    loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)

    assert_close(torch_loss.float(), loss.float())

    check_grad(model, torch_model)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_gpt(1)