mirror of https://github.com/hpcaitech/ColossalAI
[chore] minor fix
parent
404b16faf3
commit
09d6280d3e
|
@ -26,18 +26,8 @@ class MixtralPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
self.origin_attn_implement = self.model.config._attn_implementation
|
||||
# if self.shard_config.enable_tensor_parallelism:
|
||||
# # non-moe params tensor parallelism
|
||||
|
||||
# # Resize embedding
|
||||
# vocab_size = self.model.config.vocab_size
|
||||
# world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
# if vocab_size % world_size != 0:
|
||||
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
# self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
|
|
@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
rank = dist.get_rank()
|
||||
# for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
|
||||
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
||||
try:
|
||||
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
|
||||
print(f"{rank=},passed grad: {n1}, {n2}")
|
||||
except Exception as e:
|
||||
print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
|
||||
raise e
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
|
@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# print(grads_to_check)
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
||||
try:
|
||||
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
||||
print(f"{rank=},passed param before step: {n1}, {n2}")
|
||||
except Exception:
|
||||
print(
|
||||
f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
|
||||
)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
|
||||
try:
|
||||
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
|
||||
print(f"{rank=},passed param after step: {n1}, {n2}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
|
||||
)
|
||||
raise e
|
||||
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config["precision"] == "fp32":
|
||||
|
|
Loading…
Reference in New Issue