diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index f9725043e..79cddeb7b 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -172,7 +172,8 @@ def initialize_model(model: nn.Module, memory_budget: float = -1.0, save_solver_solution: bool = False, load_solver_solution: bool = False, - solution_path: str = None): + solution_path: str = None, + return_solution: bool = False): ''' This method is used to initialize the sharded model which could be used as normal pytorch model. @@ -187,6 +188,9 @@ def initialize_model(model: nn.Module, load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded from the solution_path. solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. The returned + solution will be used to debug or help to analyze the sharding result. Therefore, we will not just + return a series of integers, but return the best strategies. ''' tracer = ColoTracer() @@ -204,7 +208,14 @@ def initialize_model(model: nn.Module, gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) - return model_to_return + if return_solution: + solution_to_return = [] + nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] + for index, node in enumerate(nodes): + solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + return model_to_return, solution_to_return + else: + return model_to_return def autoparallelize(model: nn.Module, @@ -216,6 +227,7 @@ def autoparallelize(model: nn.Module, save_solver_solution: bool = False, load_solver_solution: bool = False, solver_solution_path: str = None, + return_solution: bool = False, memory_budget: float = -1.0): ''' This method is used to initialize the device mesh, extract the meta_args, and @@ -238,18 +250,26 @@ def autoparallelize(model: nn.Module, load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded from the solution_path. solver_solution_path(optional): the path to save or load the solution. + return_solution(optional): if the return_solution is True, the solution will be returned. memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, the memory budget will be infinity. ''' device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape) if meta_args is None: meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) - model = initialize_model(model, - meta_args, - device_mesh, - save_solver_solution=save_solver_solution, - load_solver_solution=load_solver_solution, - solver_solution_path=solver_solution_path, - memory_budget=memory_budget) - - return model + + rst_to_unpack = initialize_model(model, + meta_args, + device_mesh, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solver_solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget) + + if return_solution: + model, solution = rst_to_unpack + return model, solution + else: + model = rst_to_unpack + return model diff --git a/examples/language/gpt/auto_parallel_with_gpt/README.md b/examples/language/gpt/auto_parallel_with_gpt/README.md new file mode 100644 index 000000000..2c24d3b53 --- /dev/null +++ b/examples/language/gpt/auto_parallel_with_gpt/README.md @@ -0,0 +1,44 @@ +# Auto-Parallelism with GPT2 + +## Requirements + +Before you can launch training, you need to install the following requirements. + +### Install PyTorch + +```bash +#conda +conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch +#pip +pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website + +```bash +pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org +``` + +### Install transformers + +```bash +pip install transformers +``` + +### Install pulp and coin-or-cbc + +```bash +pip install pulp +conda install -c conda-forge coin-or-cbc +``` + +## Dataset + +For simplicity, the input data is randonly generated here. + +## Training + +```bash +#Run the auto parallel resnet example with 4 GPUs with a dummy dataset. +colossalai run --nproc_per_node 4 auto_parallel_with_gpt.py +``` diff --git a/examples/language/gpt/auto_parallel_with_gpt/auto_parallel_with_gpt.py b/examples/language/gpt/auto_parallel_with_gpt/auto_parallel_with_gpt.py new file mode 100644 index 000000000..85c8d64d7 --- /dev/null +++ b/examples/language/gpt/auto_parallel_with_gpt/auto_parallel_with_gpt.py @@ -0,0 +1,109 @@ +from functools import partial +from time import time +from typing import Dict, Optional, Tuple, Union + +import psutil +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import transformers +from gpt_modules import GPT2LMHeadModel, GPTLMLoss +from torch.fx import GraphModule + +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch_from_torch +from colossalai.logging import disable_existing_loggers, get_dist_logger + +BATCH_SIZE = 8 +SEQ_LENGTH = 128 +HIDDEN_DIM = 3072 +NUM_HEADS = 16 +NUM_LAYERS = 1 +VOCAB_SIZE = 50257 +NUM_STEPS = 10 +FP16 = False + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def main(): + disable_existing_loggers() + launch_from_torch(config={}) + logger = get_dist_logger() + config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) + if FP16: + model = GPT2LMHeadModel(config=config).half().to('cuda') + else: + model = GPT2LMHeadModel(config=config).to('cuda') + global_numel = sum([p.numel() for p in model.parameters()]) + + meta_input_sample = { + 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + } + + # Both device mesh initialization and model initialization will be integrated into autoparallelize + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # Enable auto-parallel + gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True) + + # print solution on rank 0 + if gpc.get_global_rank() == 0: + for node_strategy in solution: + print(node_strategy) + + # build criterion + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) + logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) + torch.cuda.synchronize() + model.train() + + for n in range(10): + # we just use randomly generated data here + input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE) + optimizer.zero_grad() + start = time() + outputs = gm(input_ids, attn_mask) + loss = criterion(outputs, input_ids) + loss.backward() + optimizer.step() + torch.cuda.synchronize() + step_time = time() - start + logger.info( + f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', + ranks=[0]) + torch.cuda.synchronize() + + +if __name__ == '__main__': + main() diff --git a/examples/language/gpt/auto_parallel_with_gpt/gpt_modules.py b/examples/language/gpt/auto_parallel_with_gpt/gpt_modules.py new file mode 100644 index 000000000..95feaec38 --- /dev/null +++ b/examples/language/gpt/auto_parallel_with_gpt/gpt_modules.py @@ -0,0 +1,253 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel +from transformers.pytorch_utils import Conv1D + + +class GPT2MLP(nn.Module): + + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +# The reason Why we don't import GPT2Attention from transformers directly is that: +# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time, +# so we have to build the customized GPT2Attention class and remove the conditional branch manually. +# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new +# order is same as megatron-lm gpt model. +class GPT2Attention(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + self.scale_attn_weights = config.scale_attn_weights + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / (value.size(-1)**0.5) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.type(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + + qkv = self.c_attn(hidden_states) + query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) + present = (key, value) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + return attn_output + + +class GPT2Block(nn.Module): + + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + ) + # residual connection + hidden_states = attn_outputs + residual + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2Model(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + + device = input_ids.device + + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + + hidden_states = inputs_embeds + position_embeds + + output_shape = input_shape + (hidden_states.size(-1),) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i]) + hidden_states = outputs + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + + return hidden_states + + +class GPT2LMHeadModel(GPT2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): + transformer_outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + ) + lm_logits = self.lm_head(transformer_outputs) + + return lm_logits + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) diff --git a/examples/language/gpt/auto_parallel_with_gpt/requirements.txt b/examples/language/gpt/auto_parallel_with_gpt/requirements.txt new file mode 100644 index 000000000..ff046ad1c --- /dev/null +++ b/examples/language/gpt/auto_parallel_with_gpt/requirements.txt @@ -0,0 +1,4 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +transformers >= 4.231 +PuLP >= 2.7.0