mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix typo of openmoe model source (#5403)
parent
e304e4db35
commit
e239cf9060
|
@@ -207,7 +207,7 @@ def main():
|
|||
coordinator.print_on_master(f"Set plugin as {plugin}")
|
||||
|
||||
# Build OpenMoe model
|
||||
- repo_name = "hpcaitech/openmoe-" + args.model_name
|
||||
+ repo_name = "hpcai-tech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
|
|
|
@@ -53,7 +53,7 @@ def fsdp_main(rank, world_size, args):
|
|||
train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
- config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
|
||||
+ config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-%s" % args.model_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
num_experts=config.num_experts,
|
||||
|
|
|
@@ -15,19 +15,19 @@ def parse_args():
|
|||
def inference(args):
|
||||
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
|
||||
if args.model == "test":
|
||||
- config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
|
||||
+ config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
|
||||
set_openmoe_args(config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_kernel=True)
|
||||
model = OpenMoeForCausalLM(config)
|
||||
else:
|
||||
- config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
|
||||
+ config = LlamaConfig.from_pretrained(f"hpcai-tech/openmoe-{args.model}")
|
||||
set_openmoe_args(config,
|
||||
num_experts=config.num_experts,
|
||||
moe_layer_interval=config.moe_layer_interval,
|
||||
enable_kernel=False)
|
||||
- model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
|
||||
+ model = OpenMoeForCausalLM.from_pretrained(f"hpcai-tech/openmoe-{args.model}", config=config)
|
||||
model = model.eval().bfloat16()
|
||||
model = model.to(torch.cuda.current_device())
|
||||
|
||||
|
|
|
@@ -269,12 +269,12 @@ def main():
|
|||
|
||||
# Build OpenMoe model
|
||||
if test_mode:
|
||||
- config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
|
||||
+ config = LlamaConfig.from_pretrained("hpcai-tech/openmoe-base")
|
||||
config.hidden_size = 128
|
||||
config.intermediate_size = 256
|
||||
config.vocab_size = 32000
|
||||
else:
|
||||
- repo_name = "hpcaitech/openmoe-" + args.model_name
|
||||
+ repo_name = "hpcai-tech/openmoe-" + args.model_name
|
||||
config = LlamaConfig.from_pretrained(repo_name)
|
||||
set_openmoe_args(
|
||||
config,
|
||||
|
|
Loading…
Reference in New Issue