@@ -15,7 +15,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import set_seed
+from _utils import tensor_equal, tensor_shard_equal, set_seed


 def init_1d_row_linear(weight):
@@ -144,20 +144,8 @@ def run_1d_hybrid_tp(model_name):
         with torch.no_grad():
-            # check param
-            for p1, p2 in zip(model.parameters(), model_torch.parameters()):
-                if p1.size() == p2.size():
-                    assert torch.allclose(p1, p2)
-                else:
-                    # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
-                    if p1.size(-1) < p2.size(-1):    # col
-                        world_size = p2.size(-1) // p1.size(-1)
-                        split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
-                    elif p1.size(0) < p2.size(0):    # row
-                        world_size = p2.size(0) // p1.size(0)
-                        split_p2 = torch.chunk(p2, world_size, dim=0)[0]
-
-                    assert torch.allclose(p1, split_p2)
+            for p, torch_p in zip(model.parameters(), model_torch.parameters()):
+                assert tensor_shard_equal(torch_p, p)

         if i > 5:
             break
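
Note on the second hunk: the removed block compared each tensor-parallel shard against the full-size parameter of the reference model by chunking along the split dimension (column or row) and always taking chunk index 0, while the new code delegates that check to the tensor_shard_equal helper now imported from the test _utils module. The helper's real implementation is not shown in this diff; the code below is only a minimal sketch of the behaviour implied by the removed block, under a hypothetical name and with an explicit rank / world_size signature (the real helper is called with just the two tensors, presumably reading rank and world size from the parallel context). Unlike the removed block, the sketch indexes the chunk by the local rank rather than always using chunk 0.

import torch

def tensor_shard_equal_sketch(full: torch.Tensor, shard: torch.Tensor,
                              rank: int, world_size: int) -> bool:
    # Unsharded parameter: compare the tensors directly.
    if full.size() == shard.size():
        return torch.allclose(full, shard)
    # Column sharding: the last dimension was split across ranks.
    if shard.size(-1) < full.size(-1):
        return torch.allclose(shard, torch.chunk(full, world_size, dim=-1)[rank])
    # Row sharding: the first dimension was split across ranks.
    if shard.size(0) < full.size(0):
        return torch.allclose(shard, torch.chunk(full, world_size, dim=0)[rank])
    # Anything beyond 1-D row/col sharding is out of scope for this sketch.
    raise NotImplementedError("only 1-D row/col sharding is handled here")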