Browse Source

[chore] minor fix

colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent
commit
09d6280d3e
  1. 12
      colossalai/shardformer/policies/mixtral.py
  2. 26
      tests/test_shardformer/test_model/test_shard_mixtral.py

12
colossalai/shardformer/policies/mixtral.py

@ -26,18 +26,8 @@ class MixtralPolicy(Policy):
pass
def preprocess(self):
self.tie_weight = self.tie_weight_check()
self.origin_attn_implement = self.model.config._attn_implementation
# if self.shard_config.enable_tensor_parallelism:
# # non-moe params tensor parallelism
# # Resize embedding
# vocab_size = self.model.config.vocab_size
# world_size = self.shard_config.tensor_parallel_size
# if vocab_size % world_size != 0:
# new_vocab_size = vocab_size + world_size - vocab_size % world_size
# self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

26
tests/test_shardformer/test_model/test_shard_mixtral.py

@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rank = dist.get_rank()
# for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed grad: {n1}, {n2}")
except Exception as e:
print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
raise e
assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# print(grads_to_check)
check_all_grad_tensors(grads_to_check)
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed param before step: {n1}, {n2}")
except Exception:
print(
f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
)
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
try:
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
print(f"{rank=},passed param after step: {n1}, {n2}")
except Exception as e:
print(
f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
)
raise e
assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":

Loading…
Cancel
Save