|
|
@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from colossalai.tensor import ProcessGroup, ShardSpec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(batch_size, seq_len, vocab_size):
|
|
|
|
def get_data(batch_size, seq_len, vocab_size):
|
|
|
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
|
|
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
|
|
@ -102,6 +104,11 @@ def parse_args():
|
|
|
|
help="Model type to use if training from scratch.",
|
|
|
|
help="Model type to use if training from scratch.",
|
|
|
|
choices=MODEL_TYPES,
|
|
|
|
choices=MODEL_TYPES,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
|
|
|
"--shardinit",
|
|
|
|
|
|
|
|
action="store_true",
|
|
|
|
|
|
|
|
help="Initialize the model with tensor parallel",
|
|
|
|
|
|
|
|
)
|
|
|
|
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
|
|
|
|
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
|
|
|
|
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
|
|
|
|
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
@ -159,16 +166,30 @@ def main():
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
init_dev = get_current_device()
|
|
|
|
init_dev = get_current_device()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# shard init prameters
|
|
|
|
|
|
|
|
if args.shardinit:
|
|
|
|
|
|
|
|
logger.info("Sharding initialization !", ranks=[0])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
logger.info("Skipping sharding initialization", ranks=[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
world_size = torch.distributed.get_world_size()
|
|
|
|
|
|
|
|
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
|
|
|
|
|
|
|
|
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
|
|
|
|
|
|
|
|
|
|
|
|
# build model
|
|
|
|
# build model
|
|
|
|
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
|
|
|
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
|
|
|
# currently, there has a bug in pretrained opt-13b
|
|
|
|
# currently, there has a bug in pretrained opt-13b
|
|
|
|
# we can not import it until huggingface fix it
|
|
|
|
# we can not import it until huggingface fix it
|
|
|
|
logger.info("Train a new model from scratch", ranks=[0])
|
|
|
|
logger.info("Train a new model from scratch", ranks=[0])
|
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half):
|
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half,
|
|
|
|
|
|
|
|
default_dist_spec=default_dist_spec,
|
|
|
|
|
|
|
|
default_pg=shard_pg):
|
|
|
|
model = OPTForCausalLM(config)
|
|
|
|
model = OPTForCausalLM(config)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
logger.info("Finetune a pre-trained model", ranks=[0])
|
|
|
|
logger.info("Finetune a pre-trained model", ranks=[0])
|
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half):
|
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half,
|
|
|
|
|
|
|
|
default_dist_spec=default_dist_spec,
|
|
|
|
|
|
|
|
default_pg=shard_pg):
|
|
|
|
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
|
|
|
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
|
|
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
|
|
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
|
|
|
config=config,
|
|
|
|
config=config,
|
|
|
@ -179,7 +200,8 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
numel = sum([p.numel() for p in model.parameters()])
|
|
|
|
numel = sum([p.numel() for p in model.parameters()])
|
|
|
|
PLACEMENT_POLICY = 'cpu'
|
|
|
|
PLACEMENT_POLICY = 'cpu'
|
|
|
|
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
|
|
|
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
|
|
|
|
|
|
|
|
pin_memory=True, strict_ddp_mode=args.shardinit)
|
|
|
|
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
|
|
|
|
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
|
|
|
|
|
|
|
|
|
|
|
|
SEQ_LEN = 1024
|
|
|
|
SEQ_LEN = 1024
|
|
|
|