|
|
@ -39,6 +39,8 @@ from colossalai.nn.parallel import GeminiDDP |
|
|
|
from colossalai.utils import get_current_device |
|
|
|
from colossalai.utils import get_current_device |
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext |
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from colossalai.tensor import ProcessGroup, ShardSpec |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(batch_size, seq_len, vocab_size): |
|
|
|
def get_data(batch_size, seq_len, vocab_size): |
|
|
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) |
|
|
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) |
|
|
@ -102,6 +104,11 @@ def parse_args(): |
|
|
|
help="Model type to use if training from scratch.", |
|
|
|
help="Model type to use if training from scratch.", |
|
|
|
choices=MODEL_TYPES, |
|
|
|
choices=MODEL_TYPES, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
parser.add_argument( |
|
|
|
|
|
|
|
"--shardinit", |
|
|
|
|
|
|
|
action="store_true", |
|
|
|
|
|
|
|
help="Initialize the model with tensor parallel", |
|
|
|
|
|
|
|
) |
|
|
|
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") |
|
|
|
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") |
|
|
|
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") |
|
|
|
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") |
|
|
|
args = parser.parse_args() |
|
|
|
args = parser.parse_args() |
|
|
@ -159,16 +166,30 @@ def main(): |
|
|
|
else: |
|
|
|
else: |
|
|
|
init_dev = get_current_device() |
|
|
|
init_dev = get_current_device() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# shard init prameters |
|
|
|
|
|
|
|
if args.shardinit: |
|
|
|
|
|
|
|
logger.info("Sharding initialization !", ranks=[0]) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
logger.info("Skipping sharding initialization", ranks=[0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
world_size = torch.distributed.get_world_size() |
|
|
|
|
|
|
|
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None |
|
|
|
|
|
|
|
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None |
|
|
|
|
|
|
|
|
|
|
|
# build model |
|
|
|
# build model |
|
|
|
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': |
|
|
|
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': |
|
|
|
# currently, there has a bug in pretrained opt-13b |
|
|
|
# currently, there has a bug in pretrained opt-13b |
|
|
|
# we can not import it until huggingface fix it |
|
|
|
# we can not import it until huggingface fix it |
|
|
|
logger.info("Train a new model from scratch", ranks=[0]) |
|
|
|
logger.info("Train a new model from scratch", ranks=[0]) |
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half): |
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half, |
|
|
|
|
|
|
|
default_dist_spec=default_dist_spec, |
|
|
|
|
|
|
|
default_pg=shard_pg): |
|
|
|
model = OPTForCausalLM(config) |
|
|
|
model = OPTForCausalLM(config) |
|
|
|
else: |
|
|
|
else: |
|
|
|
logger.info("Finetune a pre-trained model", ranks=[0]) |
|
|
|
logger.info("Finetune a pre-trained model", ranks=[0]) |
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half): |
|
|
|
with ColoInitContext(device=init_dev, dtype=torch.half, |
|
|
|
|
|
|
|
default_dist_spec=default_dist_spec, |
|
|
|
|
|
|
|
default_pg=shard_pg): |
|
|
|
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, |
|
|
|
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, |
|
|
|
from_tf=bool(".ckpt" in args.model_name_or_path), |
|
|
|
from_tf=bool(".ckpt" in args.model_name_or_path), |
|
|
|
config=config, |
|
|
|
config=config, |
|
|
@ -179,7 +200,8 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
numel = sum([p.numel() for p in model.parameters()]) |
|
|
|
numel = sum([p.numel() for p in model.parameters()]) |
|
|
|
PLACEMENT_POLICY = 'cpu' |
|
|
|
PLACEMENT_POLICY = 'cpu' |
|
|
|
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) |
|
|
|
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, |
|
|
|
|
|
|
|
pin_memory=True, strict_ddp_mode=args.shardinit) |
|
|
|
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) |
|
|
|
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) |
|
|
|
|
|
|
|
|
|
|
|
SEQ_LEN = 1024 |
|
|
|
SEQ_LEN = 1024 |
|
|
|