mirror of https://github.com/hpcaitech/ColossalAI
parent
50ec3a7e06
commit
b3a03e4bfd
|
@ -15,7 +15,7 @@ from colossalai.context import ParallelMode
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.nn.optimizer import ColoOptimizer
|
from colossalai.nn.optimizer import ColoOptimizer
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from _utils import set_seed
|
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||||
|
|
||||||
|
|
||||||
def init_1d_row_linear(weight):
|
def init_1d_row_linear(weight):
|
||||||
|
@ -144,20 +144,8 @@ def run_1d_hybrid_tp(model_name):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# check param
|
# check param
|
||||||
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
|
for p, torch_p in zip(model.parameters(), model_torch.parameters()):
|
||||||
if p1.size() == p2.size():
|
assert tensor_shard_equal(torch_p, p)
|
||||||
assert torch.allclose(p1, p2)
|
|
||||||
else:
|
|
||||||
# TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
|
|
||||||
if p1.size(-1) < p2.size(-1): # col
|
|
||||||
world_size = p2.size(-1) // p1.size(-1)
|
|
||||||
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
|
|
||||||
|
|
||||||
elif p1.size(0) < p2.size(0): # row
|
|
||||||
world_size = p2.size(0) // p1.size(0)
|
|
||||||
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
|
|
||||||
|
|
||||||
assert torch.allclose(p1, split_p2)
|
|
||||||
|
|
||||||
if i > 5:
|
if i > 5:
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue