mirror of https://github.com/hpcaitech/ColossalAI
[chore] minor fix
parent 404b16faf3
commit 09d6280d3e
@@ -26,18 +26,8 @@ class MixtralPolicy(Policy):
         pass
 
     def preprocess(self):
+        self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
-        # if self.shard_config.enable_tensor_parallelism:
-        #     # non-moe params tensor parallelism
-
-        #     # Resize embedding
-        #     vocab_size = self.model.config.vocab_size
-        #     world_size = self.shard_config.tensor_parallel_size
-
-        #     if vocab_size % world_size != 0:
-        #         new_vocab_size = vocab_size + world_size - vocab_size % world_size
-        #         self.model.resize_token_embeddings(new_vocab_size)
-
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
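The block this hunk deletes was already commented out; it sketched padding the vocabulary to a multiple of the tensor-parallel world size so the embedding rows could be split evenly across ranks. For reference only, a minimal standalone sketch of that padding rule (`pad_vocab_size` is a hypothetical helper, not part of the policy, which read these values from `self.model.config` and `self.shard_config`):

```python
# Hypothetical standalone form of the padding rule from the removed comment block.
def pad_vocab_size(vocab_size: int, world_size: int) -> int:
    """Round vocab_size up to the next multiple of world_size so the
    embedding rows split evenly across tensor-parallel ranks."""
    if vocab_size % world_size != 0:
        vocab_size += world_size - vocab_size % world_size
    return vocab_size


assert pad_vocab_size(32000, 8) == 32000  # already divisible: unchanged
assert pad_vocab_size(32003, 8) == 32008  # rounded up to a multiple of 8
```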
@@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     rank = dist.get_rank()
     # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed grad: {n1}, {n2}")
-        except Exception as e:
-            print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
-            raise e
+        assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
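The try/except wrappers removed here were per-rank debug scaffolding around `torch.testing.assert_close`; the gradient check itself is unchanged. A minimal self-contained sketch of the same comparison pattern, using toy modules (`ref_model` and `shard_model` are illustrative names, not taken from the test):

```python
import torch
from torch import nn
from torch.testing import assert_close

# Two identical toy models standing in for the original and sharded Mixtral.
ref_model = nn.Linear(4, 4)
shard_model = nn.Linear(4, 4)
shard_model.load_state_dict(ref_model.state_dict())

x = torch.randn(2, 4)
ref_model(x).sum().backward()
shard_model(x).sum().backward()

# Same loop shape as above: pair named parameters and compare gradients
# within loose tolerances, ignoring dtype differences.
for (n1, p1), (n2, p2) in zip(ref_model.named_parameters(), shard_model.named_parameters()):
    assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
```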
@@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # print(grads_to_check)
     check_all_grad_tensors(grads_to_check)
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed param before step: {n1}, {n2}")
-        except Exception:
-            print(
-                f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
-            )
+        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed param after step: {n1}, {n2}")
-        except Exception as e:
-            print(
-                f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
-            )
-            raise e
+        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config["precision"] == "fp32":
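This hunk strips the same kind of debug scaffolding around the parameter checks before and after the optimizer step. A self-contained sketch of that before/after-step pattern with toy models and SGD optimizers (all names here are illustrative, not the ColossalAI objects):

```python
import torch
from torch import nn
from torch.testing import assert_close

org_model = nn.Linear(4, 4)
sharded_model = nn.Linear(4, 4)
sharded_model.load_state_dict(org_model.state_dict())

org_optimizer = torch.optim.SGD(org_model.parameters(), lr=1e-3)
sharded_optimizer = torch.optim.SGD(sharded_model.parameters(), lr=1e-3)

x = torch.randn(2, 4)
org_model(x).sum().backward()
sharded_model(x).sum().backward()

# Parameters agree before the step ...
for (n1, p1), (n2, p2) in zip(org_model.named_parameters(), sharded_model.named_parameters()):
    assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)

# ... and must still agree after both optimizers step.
org_optimizer.step()
sharded_optimizer.step()
for (n1, p1), (n2, p2) in zip(org_model.named_parameters(), sharded_model.named_parameters()):
    assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
```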