mirror of https://github.com/hpcaitech/ColossalAI
parent
24651fdd4f
commit
ddcf58cacf
@ -1,73 +0,0 @@
# 🗄 Device

## 📚 Table of Contents

- [🗄 Device](#-device)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [📝 Design](#-design)
  - [🔨 Usage](#-usage)

## 🔗 Introduction

This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network.

## 📝 Design

This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa), and the device array can be represented as a 1D or 2D mesh. We will extend the device mesh to support 3D meshes in the future.

## 🔨 Usage

- Create a device mesh

```python
import torch

from colossalai.device.device_mesh import DeviceMesh

# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```

- View the mesh

```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)

# view the logical mesh with global ranks
# expected output
# [
#     [0, 1],
#     [2, 3]
# ]
print(device_mesh.logical_mesh_id)

# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```

- Initialize the process group

```python
# initialize process groups
device_mesh.init_logical_process_group()

# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))

# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
@ -1,97 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
try:
|
||||
import fused_mix_prec_layer_norm_cuda
|
||||
except ImportError:
|
||||
fused_mix_prec_layer_norm_cuda = None
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
r"""Layernorm
|
||||
|
||||
Args:
|
||||
input: input matrix.
|
||||
weight: weight matrix.
|
||||
bias: bias matrix.
|
||||
normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps: a value added to the denominator for numerical stability
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.eps = eps
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
|
||||
bias_, ctx.eps)
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar,
|
||||
input_, ctx.normalized_shape,
|
||||
weight_, bias_, ctx.eps)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
|
||||
class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
"""
|
||||
Linear layer execution with asynchronous communication in backprop.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.parallel_mode = parallel_mode
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
|
||||
output = torch.matmul(input_, weight.t())
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
|
||||
total_input = input
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
|
||||
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# all-reduce scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
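

# Hedged usage sketch (illustrative, not part of the original file): the ParallelMode value and
# the tensor shapes below are assumptions, and a distributed environment must already have been
# launched so that gpc holds the corresponding tensor-parallel process group.
if __name__ == "__main__":
    from colossalai.context import ParallelMode    # assumed import path for the legacy context API

    # [batch, seq_len, hidden] activation and an [out_features, hidden] weight
    x = torch.randn(4, 16, 256, device="cuda", requires_grad=True)
    w = torch.randn(512, 256, device="cuda", requires_grad=True)
    b = torch.zeros(512, device="cuda", requires_grad=True)

    # forward is a plain matmul; in backward, the all-reduce of the input gradient over the
    # tensor-parallel group is overlapped with the weight-gradient matmul
    out = linear_with_async_comm(x, w, b, ParallelMode.PARALLEL_1D, async_grad_allreduce=True)
    out.sum().backward()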
|
@ -1,105 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class DistCrossEntropy(Function):
|
||||
r"""
|
||||
Overwrite the forward and backward function to calculate the cross entropy loss before gather
|
||||
|
||||
Args:
|
||||
Function (:class:`torch.autograd.Function`): default
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
|
||||
r"""
|
||||
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
||||
loss = -log(exp(x[class]) / sum(exp(x[i])))
and can be rewritten as:
loss = log(sum(exp(x[i]))) - x[class]

To avoid `nan` in log(sum(exp(x[i]))), we subtract the max of x[i] first
|
||||
|
||||
Args:
|
||||
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
||||
[batch_size, seq_len, vocab_size]
|
||||
target (:class:`torch.Tensor`): The target labels, shape is
|
||||
[batch_size, seq_len]
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The cross entropy loss
|
||||
"""
|
||||
# get the max
|
||||
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
||||
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
|
||||
|
||||
# subtract the max to avoid the sum of exp becoming too large and the log becoming nan
|
||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
# mask the target in the local device
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
|
||||
# [down, up) => false, other device and -100 => true
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_threshold = rank * delta
up_threshold = down_threshold + delta
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
|
||||
masked_target[mask] = 0
|
||||
|
||||
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
|
||||
logits_2d = vocab_logits.view(-1, partition_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
|
||||
# extract the x[class] and set the x[other device] to zero
|
||||
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
|
||||
masked_target_1d]
|
||||
pred_logits_1d = pred_logits_1d.clone().contiguous()
|
||||
pred_logits = pred_logits_1d.view_as(target)
|
||||
pred_logits[mask] = 0.0
|
||||
|
||||
# all-reduce to get all x[i, y] across devices
|
||||
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
|
||||
exp_logits = vocab_logits
|
||||
torch.exp(vocab_logits, out=exp_logits)
|
||||
sum_exp_logits = torch.sum(exp_logits, dim=-1)
|
||||
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
|
||||
|
||||
# calculate the loss
|
||||
# loss = log(sum(exp(x[i]))) - x[class]
|
||||
loss = torch.log(sum_exp_logits) - pred_logits
|
||||
loss = torch.sum(loss).div_(loss.numel())
|
||||
|
||||
# calculate the softmax
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# retrieve the saved tensors
|
||||
exp_logits, mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# use exp logits as the input grad
|
||||
grad_logits = exp_logits
|
||||
partition_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partition_vocab_size)
|
||||
|
||||
update = 1.0 - mask.view(-1).float()
|
||||
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
||||
|
||||
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
||||
return grad_logits, None, None
|
||||
|
||||
|
||||
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
return DistCrossEntropy.apply(vocab_logits, labels)
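

# Hedged usage sketch (illustrative, not part of the original file): each rank passes only its
# local slice of the vocabulary dimension, mirroring the chunking done in the test script below.
if __name__ == "__main__":
    # assumes torch.distributed has already been initialised (e.g. via colossalai.launch_from_torch)
    world_size, rank = dist.get_world_size(), dist.get_rank()
    full_logits = torch.randn(2, 4, 8, device="cuda", requires_grad=True)    # [batch, seq_len, global_vocab]
    labels = torch.randint(0, 8, (2, 4), device="cuda")                      # [batch, seq_len]
    local_logits = full_logits.chunk(world_size, dim=-1)[rank].contiguous()
    loss = applyDistCrossEntropy(local_logits, labels)
    loss.backward()    # gradients flow back to the local logits slice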
|
@ -1,58 +0,0 @@
import os
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class SeedManager:
    """
    This class manages the CUDA random number generator state so that dropout layers
    can draw from a dedicated, per-rank random seed.
    """

    def __init__(self):
        original_state = torch.cuda.get_rng_state()
        seed = int(f"{int(time.time())}{os.environ['RANK']}")
        torch.cuda.manual_seed(int(seed))
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original_state)

    def set_mode(self, rng_state):
        torch.cuda.set_rng_state(rng_state)

    def get_current_mode(self):
        current_state = torch.cuda.get_rng_state()
        return current_state

    @contextmanager
    def dropout_mode(self):
        """
        A context manager that switches to the dropout RNG state and restores the
        original state on exit.

        Usage:
        ::
            >>> with _seed_manager.dropout_mode():
            >>>     input = super().forward(input)
        """
        try:
            current_mode = self.get_current_mode()
            yield self.set_mode(self.dropout_state)
        finally:
            self.dropout_state = self.get_current_mode()
            self.set_mode(current_mode)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, input):
        with _seed_manager.dropout_mode():
            input = super().forward(input)
        return input
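

# Hedged usage sketch (illustrative): Dropout1D is a drop-in replacement for nn.Dropout, but each
# rank draws its dropout mask from its own seed (derived from the RANK env var), so tensor-parallel
# ranks do not share masks.
if __name__ == "__main__":
    layer = Dropout1D(p=0.2).to("cuda")
    x = torch.randn(5, 4, device="cuda")
    print(layer(x))    # masks differ across ranks
    print(layer(x))    # and across successive calls on the same rank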
|
File diff suppressed because it is too large
@ -1,67 +0,0 @@
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import MaskedLMOutput
|
||||
|
||||
from ..layer.dist_crossentropy import applyDistCrossEntropy
|
||||
|
||||
|
||||
class BertForMaskedLM_(BertForMaskedLM):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
# print("[Inject OK] Injected forward method")
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
masked_lm_loss = None
|
||||
|
||||
if labels is not None:
|
||||
masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels)
|
||||
# if labels is not None:
|
||||
# loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
@ -1,58 +0,0 @@
import torch.nn as nn


def build_policies():
    r"""
    Build the policies for the supported models

    Return:
        The dict mapping model classes to their policies
    """
    auto_policy_dict = {}

    from transformers import BertForMaskedLM

    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers import BertForSequenceClassification

    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    from transformers import GPT2Model

    from .gpt2 import GPT2Policy
    auto_policy_dict[GPT2Model] = GPT2Policy

    from transformers import GPT2LMHeadModel

    from .gpt2 import GPT2LMHeadModelPolicy
    auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

    return auto_policy_dict


def get_autopolicy(model: nn.Module):
    r"""
    Return the auto policy for the model

    Args:
        model (:class:`nn.Module`): The model to get the auto policy for

    Return:
        :class:`Policy`: The auto policy for the model
    """
    auto_policy_dict = build_policies()
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
        )
    return policy


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
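

# Hedged usage sketch (illustrative): looking up the policy for a HuggingFace model instance.
# Downloading "bert-base-uncased" is only an example; any supported model class works.
if __name__ == "__main__":
    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    policy = get_autopolicy(model)    # -> BertForMaskedLMPolicy
    print(policy)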
|
@ -1,217 +0,0 @@
|
||||
# part of code modified from https://github.com/tunib-ai/parallelformers
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class Argument:
|
||||
r"""
|
||||
The argument class for the policy
|
||||
|
||||
Args:
|
||||
attr_dict (Dict[str, Any]): The dict for the param setting
|
||||
param_funcs (:class:`List[Callable]`): The list for the param functions
|
||||
"""
|
||||
attr_dict: Dict[str, Any]
|
||||
param_funcs: List[Callable]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Layer:
|
||||
r"""
|
||||
The layer object for the policy
|
||||
|
||||
Args:
|
||||
weight (str): The weight suffix of the layer
|
||||
bias (str): The bias suffix of the layer
|
||||
replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
|
||||
ignore (bool): Whether to ignore this layer if it is not in the model
|
||||
reversed (bool): Whether the weight in the layer is transposed; the weight of `torch.nn.Linear` is commonly [out, in],
    but the weight of the GPT2 `Conv1D` layer is [in, out], which is reversed.
n_cast (int): The number of parts the weight is interleaved into; e.g. for the fused q, k, v weight in an attention layer, n_cast should be 3.
    In TP we commonly just chunk the weight by the number of devices, but for a fused multi-head attention weight we need
    to chunk it by the number of devices * n_cast so that each device holds a part of each of the Q, K and V weights.
|
||||
"""
|
||||
weight: str = None
|
||||
bias: str = None
|
||||
replace_layer: Any = None
|
||||
ignore: bool = False
|
||||
reversed: bool = False
|
||||
n_cast: int = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Col_Layer(Layer):
|
||||
r"""
|
||||
Class for the column-sharded layer in MegatronLM
|
||||
|
||||
Args:
|
||||
gather_output (bool): Whether to gather the output of the layer
|
||||
"""
|
||||
gather_output: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Row_Layer(Layer):
|
||||
r"""
|
||||
Class for the row-sharded layer in MegatronLM
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Policy():
|
||||
r"""
|
||||
The base class for all the policies
|
||||
For each different model, it should have a different policy class, like BertPolicy for Bert Model
|
||||
or OPTPolicy for OPT model.
|
||||
AutoPolicy:
|
||||
Shardformer already defines some policies for huggingface models. Just set ``custom_policy`` = None
to use the auto policy. In the shardformer auto policy, we define a base policy for one model type,
like BertPolicy, and for each different Bert model in huggingface, such as BertForMaskedLM or
BertForSequenceClassification, we define a different policy class and overwrite methods
like ``inject_policy`` to modify the forward and backward process.
|
||||
|
||||
CustomPolicy:
|
||||
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
|
||||
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
|
||||
class for the example.
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
|
||||
r"""
|
||||
Return the dict for the modify policy, the key is the original layer class and the value is the
|
||||
argument for the modify layer
|
||||
|
||||
Args:
|
||||
model_config (:class:`transformer.Config`): The config of the transformer model
|
||||
shard_config (:class:`ShardConfig`): The config for sharding model
|
||||
|
||||
Return:
|
||||
Dict for the modify policy,
|
||||
::
|
||||
{
|
||||
origin layer class1 (nn.Module): Argument(
|
||||
attr_dict = {
|
||||
argument1: value1,
|
||||
argument2: value2,
|
||||
...
|
||||
},
|
||||
param_funcs = [
|
||||
staticmethod1,
|
||||
staticmethod2,
|
||||
...
|
||||
]
|
||||
),
|
||||
origin layer class2 (nn.Module): Argument(
|
||||
attr_dict = {
|
||||
argument1: value1,
|
||||
argument2: value2,
|
||||
...
|
||||
},
|
||||
param_funcs = [
|
||||
staticmethod1,
|
||||
staticmethod2,
|
||||
...
|
||||
]
|
||||
),
|
||||
...
|
||||
}
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Tuple[nn.Module, nn.Module]:
|
||||
r"""
|
||||
Return the dict for the inject model
|
||||
|
||||
Return:
|
||||
The injected model, key is the original model and value is the new shardmodel
|
||||
::
|
||||
(OriginModel, CustomModel)
|
||||
in `CustomModel`, we can overwrite the forward and backward process
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def binding_policy() -> Dict:
|
||||
r"""
|
||||
Return the dict for the binding model
|
||||
|
||||
Return:
|
||||
This method should return the binding relationship for layers that share a weight or bias;
the keys and values are the suffixes of the shared weights or biases of the model
|
||||
::
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
"""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> List:
|
||||
r"""
|
||||
Attention qkv layer
|
||||
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
|
||||
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
|
||||
in ``Layer`` object can refer to the ``Layer`` class.
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer objects describing how to shard the attention qkv layers
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> List:
|
||||
r"""
|
||||
Attention output projection layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> List:
|
||||
r"""
|
||||
h -> 4h mlp layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> List:
|
||||
r"""
|
||||
4h -> h mlp layer
|
||||
|
||||
Returns:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> List:
|
||||
r"""
|
||||
Partially slice the embedding layer
|
||||
|
||||
Return:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> List:
|
||||
r"""
|
||||
Partially slice the unembedding (output projection) layer
|
||||
|
||||
Return:
|
||||
List[Layer]: List of layer object
|
||||
"""
|
||||
return None
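

# Hedged sketch (illustrative, not part of the original file): the skeleton of a custom policy
# for a hypothetical ``MyBlock`` module. The attribute and suffix names are made up for
# illustration; see BertPolicy and GPT2Policy for real policies against HuggingFace modules.
class MyBlock(nn.Module):

    def __init__(self, hidden: int = 8, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(hidden, 3 * hidden)


class MyBlockPolicy(Policy):

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
        # give each device a slice of the heads and shard the fused qkv projection column-wise
        return {
            MyBlock:
                Argument(attr_dict={"num_heads": model_config.num_attention_heads // shard_config},
                         param_funcs=[MyBlockPolicy.attn_in]),
        }

    @staticmethod
    def attn_in() -> List:
        # the fused qkv weight stacks Q, K and V, hence n_cast=3
        return [Col_Layer(weight="qkv.weight", bias="qkv.bias", n_cast=3, replace_layer=None)]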
|
@ -1,170 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||
return {
|
||||
BertLayer:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard hidden size
|
||||
"attention.self.all_head_size": config.hidden_size // world_size,
|
||||
"crossattention.self.all_head_size": config.hidden_size // world_size,
|
||||
# 2. shard number of heads
|
||||
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
},
|
||||
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
|
||||
BertEmbeddings:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
|
||||
},
|
||||
param_funcs=[
|
||||
BertPolicy.embedding,
|
||||
]),
|
||||
BertLMPredictionHead:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
},
|
||||
param_funcs=[
|
||||
BertPolicy.unembedding,
|
||||
])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def binding_policy() -> Dict:
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
weight="attention.self.query.weight",
|
||||
bias="attention.self.query.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
weight="attention.self.key.weight",
|
||||
bias="attention.self.key.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
weight="attention.self.value.weight",
|
||||
bias="attention.self.value.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
Col_Layer(
|
||||
weight="crossattention.self.query.weight",
|
||||
bias="crossattention.self.query.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
Col_Layer(
|
||||
weight="crossattention.self.key.weight",
|
||||
bias="crossattention.self.key.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
Col_Layer(
|
||||
weight="crossattention.self.value.weight",
|
||||
bias="crossattention.self.value.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> List:
|
||||
return [
|
||||
Row_Layer(
|
||||
weight="attention.output.dense.weight",
|
||||
bias="attention.output.dense.bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
),
|
||||
Row_Layer(
|
||||
weight="crossattention.output.dense.weight",
|
||||
bias="crossattention.output.dense.bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
ignore=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
weight="intermediate.dense.weight",
|
||||
bias="intermediate.dense.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> List:
|
||||
return [
|
||||
Row_Layer(
|
||||
weight="output.dense.weight",
|
||||
bias="output.dense.bias",
|
||||
replace_layer=col_nn.Linear1D_Row,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> List:
|
||||
return [Col_Layer(
|
||||
weight="word_embeddings.weight",
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)]
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
weight="decoder.weight",
|
||||
bias="decoder.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
# gather_output=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
|
||||
|
||||
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Tuple[nn.Module, nn.Module]:
|
||||
return (BertForMaskedLM, BertForMaskedLM_)
|
||||
|
||||
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Dict:
|
||||
return {}
|
||||
|
||||
|
||||
# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
# _ = BertForMaskedLMPolicy(model)
|
||||
# print(isinstance(model,list(_.inject_policy().keys())[0]))
|
@ -1,118 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
|
||||
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
|
||||
|
||||
|
||||
class GPT2Policy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
return {
|
||||
GPT2Model:
|
||||
Argument(attr_dict={}, param_funcs=[
|
||||
GPT2Policy.embedding,
|
||||
]),
|
||||
GPT2Block:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. reduce hidden size
|
||||
"attn.embed_dim": config.hidden_size // world_size,
|
||||
"attn.split_size": config.hidden_size // world_size,
|
||||
"crossattention.embed_dim": config.hidden_size // world_size,
|
||||
"crossattention.split_size": config.hidden_size // world_size,
|
||||
# 2. reduce number of heads
|
||||
"attn.num_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.num_heads": config.num_attention_heads // world_size,
|
||||
},
|
||||
param_funcs=[
|
||||
GPT2Policy.attn_in,
|
||||
GPT2Policy.attn_out,
|
||||
GPT2Policy.mlp_in,
|
||||
GPT2Policy.mlp_out,
|
||||
]),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def attn_in() -> List:
|
||||
return [
|
||||
Col_Layer(weight="attn.c_attn.weight",
|
||||
bias="attn.c_attn.bias",
|
||||
n_cast=3,
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Col),
|
||||
Col_Layer(weight="crossattention.c_attn.weight",
|
||||
bias="crossattention.c_attn.bias",
|
||||
n_cast=2,
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Col),
|
||||
Col_Layer(weight="crossattention.q_attn.weight",
|
||||
bias="crossattention.q_attn.bias",
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Col)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> List:
|
||||
return [
|
||||
Row_Layer(weight="attn.c_proj.weight",
|
||||
bias="attn.c_proj.bias",
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Row),
|
||||
Row_Layer(weight="crossattention.c_proj.weight",
|
||||
bias="crossattention.c_proj.bias",
|
||||
reversed=True,
|
||||
ignore=True,
|
||||
replace_layer=col_nn.Linear1D_Row)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> List:
|
||||
return [
|
||||
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> List:
|
||||
return [
|
||||
Row_Layer(weight="mlp.c_proj.weight",
|
||||
bias="mlp.c_proj.bias",
|
||||
reversed=True,
|
||||
replace_layer=col_nn.Linear1D_Row)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> List:
|
||||
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
|
||||
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
class GPT2LMHeadModelPolicy(GPT2Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size):
|
||||
base_argument = GPT2Policy.argument_policy(config, world_size)
|
||||
argument = {
|
||||
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
|
||||
GPT2LMHeadModelPolicy.unembedding,
|
||||
]),
|
||||
}
|
||||
argument.update(base_argument)
|
||||
return argument
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> List:
|
||||
return [
|
||||
Col_Layer(weight="lm_head.weight",
|
||||
bias="lm_head.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True)
|
||||
]
|
@ -1,5 +0,0 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder, shard_model
from .slicer import Slicer

__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
@ -1,20 +0,0 @@
from dataclasses import dataclass
from typing import Optional

__all__ = ['ShardConfig']


@dataclass
class ShardConfig:
    """
    The config for sharding the huggingface model for testing
    """
    rank: int
    fp16: bool = True
    num_gpus: int = 2
    world_size: int = 2
    backend: str = "nccl"
    verbose: str = 'simple'
    seed: Optional[int] = None
    require_grad: bool = False
    master_addr: str = "127.0.0.1"
    master_port: int = 29500
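

# Hedged usage sketch (illustrative): construct a config for one rank of a 2-GPU run.
# In the real driver the rank and world_size come from the launcher environment.
if __name__ == "__main__":
    config = ShardConfig(rank=0, world_size=2)
    print(config)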
|
@ -1,266 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from ..policies.autopolicy import get_autopolicy
|
||||
from ..policies.basepolicy import Policy
|
||||
from ..utils.utils import getattr_, hasattr_, setattr_
|
||||
from .shard_config import ShardConfig
|
||||
from .slicer import Slicer
|
||||
|
||||
__all__ = ['ModelSharder', 'shard_model']
|
||||
|
||||
|
||||
class ModelSharder(object):
|
||||
r"""
|
||||
Shard the original huggingface model according to the policy
|
||||
|
||||
Args:
|
||||
policy (:class:`Policy`): The policy to shard the model
|
||||
model (:class:`torch.nn.Module`): The model to shard
|
||||
shard_config: The setting of distributed model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
policy: Policy,
|
||||
shard_config: ShardConfig = None, # TODO
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.policy = get_autopolicy(self.model) if policy is None else policy
|
||||
self.slicer = Slicer(shard_config)
|
||||
self.shard_config = shard_config
|
||||
self.model_config = self.model.config
|
||||
|
||||
def shard(self) -> None:
|
||||
self.reshape_embedding()
|
||||
self.inject_model(self.model)
|
||||
self.replace_layer(self.model)
|
||||
self.bind_layer(self.model)
|
||||
|
||||
def reshape_embedding(self,) -> None:
|
||||
r"""
|
||||
Reshape the embedding layer to make the vocabulary size divisible by world_size
|
||||
"""
|
||||
vocab_size = self.model_config.vocab_size
|
||||
world_size = self.shard_config.world_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.model_config = self.model.config
|
||||
|
||||
def inject_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
) -> None:
|
||||
r"""
|
||||
Replace the model's methods with those of the policy-defined model class.
Mainly modifies the forward and backward to fit the distributed model.
|
||||
|
||||
e.g.
|
||||
::
|
||||
BertForMaskedLM.forward -> BertForMaskedLM_.forward
|
||||
"""
|
||||
inject_policy = self.policy.inject_policy()
|
||||
|
||||
if inject_policy is None:
|
||||
return
|
||||
org_model_cls = inject_policy[0]
|
||||
shard_model_cls = inject_policy[1]
|
||||
|
||||
if model.__class__ == org_model_cls:
|
||||
for key in shard_model_cls.__dict__.keys():
|
||||
if hasattr(model.__class__, key):
|
||||
setattr(
|
||||
model.__class__,
|
||||
key,
|
||||
getattr(shard_model_cls, key),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{model.__class__} is not implemented so far")
|
||||
|
||||
def replace_layer(
|
||||
self,
|
||||
model: nn.Module,
|
||||
) -> None:
|
||||
r"""
|
||||
Replace the layers according to the policy, one by one
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): The layer to shard
|
||||
"""
|
||||
argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size)
|
||||
for argument_policy in argument_policies.items():
|
||||
origin_layer_cls = argument_policy[0]
|
||||
attr_dict = argument_policy[1].attr_dict
|
||||
param_funcs = argument_policy[1].param_funcs
|
||||
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
|
||||
|
||||
def traverse_replace_layer(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
origin_cls: nn.Module,
|
||||
attr_dict: Dict[str, Any],
|
||||
param_funcs: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
Recursively traverse the model and apply the replace-layer operation to every layer matching ``origin_cls``
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The object of layer to shard
|
||||
origin_cls (:class:`transformers.model`): The origin layer class
|
||||
attr_dict (Dict): The attribute dict to modify
|
||||
param_funcs (:class:`List[Callable]`): The functions that describe how to shard the matched layers
|
||||
"""
|
||||
if layer.__class__ == origin_cls:
|
||||
for k, v in attr_dict.items():
|
||||
setattr_(layer, k, v, ignore=True)
|
||||
self.shard_one_layer(layer, param_funcs)
|
||||
for name, child in layer.named_children():
|
||||
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
|
||||
return layer
|
||||
|
||||
def shard_one_layer(
|
||||
self,
|
||||
org_layer: nn.Module,
|
||||
param_funcs: List[Callable],
|
||||
) -> None:
|
||||
r"""
|
||||
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
|
||||
|
||||
Args:
|
||||
org_layer (:class:`torch.nn.Module`): The origin layer object to shard
|
||||
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
|
||||
|
||||
"""
|
||||
for func in param_funcs:
|
||||
policy_layers = func()
|
||||
for policy_layer in policy_layers:
|
||||
weight = None
|
||||
bias = None
|
||||
weight_attr = policy_layer.weight
|
||||
bias_attr = policy_layer.bias
|
||||
replace_layer_cls = policy_layer.replace_layer
|
||||
ignore = policy_layer.ignore
|
||||
n_cast = policy_layer.n_cast
|
||||
reversed = policy_layer.reversed
|
||||
if policy_layer.__class__.__name__ == "Col_Layer":
|
||||
gather_output = policy_layer.gather_output
|
||||
|
||||
if weight_attr is not None:
|
||||
if hasattr_(org_layer, weight_attr):
|
||||
weight = getattr_(org_layer, weight_attr)
|
||||
elif not ignore:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}")
|
||||
|
||||
if bias_attr is not None:
|
||||
if hasattr_(org_layer, bias_attr):
|
||||
bias = getattr_(org_layer, bias_attr)
|
||||
elif not ignore:
|
||||
raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}")
|
||||
|
||||
# the layer doesn't have the attribute named in the policy, and ignore is true
|
||||
if weight is None and bias is None and ignore:
|
||||
continue
|
||||
|
||||
# set the sliced weight and bias to the new nn_col layer
|
||||
assert weight is not None or bias is not None
|
||||
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
|
||||
|
||||
# slice weight and bias
|
||||
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
|
||||
|
||||
# create new object to replace the origin layer
|
||||
if replace_layer_cls is not None:
|
||||
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
|
||||
if replace_layer_cls.__name__ == "Linear1D_Row":
|
||||
replace_layer = replace_layer_cls(weight.shape[1],
|
||||
weight.shape[0],
|
||||
bias=False if bias is None else True)
|
||||
elif replace_layer_cls.__name__ == "Linear1D_Col":
|
||||
replace_layer = replace_layer_cls(weight.shape[0],
|
||||
weight.shape[1],
|
||||
bias=False if bias is None else True,
|
||||
gather_output=gather_output)
|
||||
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding):
|
||||
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
|
||||
getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True))
|
||||
setattr_(org_layer, layer_attr, replace_layer, ignore=ignore)
|
||||
self.set_param(replace_layer, weight, bias)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far")
|
||||
# do not replace the layer object, just replace the weight and bias
|
||||
else:
|
||||
self.set_param(org_layer, weight, bias, layer_attr)
|
||||
|
||||
def set_param(self,
|
||||
layer: Any,
|
||||
weight: torch.Tensor = None,
|
||||
bias: torch.Tensor = None,
|
||||
layer_attr: str = "") -> None:
|
||||
r"""
|
||||
Reset the weight and bias of the layer object
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The layer object
|
||||
layer_attr (str): The attribute name of the layer
|
||||
weight (:class:`torch.Tensor`): The weight of the layer
|
||||
bias (:class:`torch.Tensor`): The bias of the layer
|
||||
"""
|
||||
assert weight is not None or bias is not None
|
||||
if weight is not None:
|
||||
setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous()))
|
||||
self.set_layer_size(layer, layer_attr, weight.shape)
|
||||
if bias is not None:
|
||||
setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous()))
|
||||
|
||||
def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None:
|
||||
r"""
|
||||
Set the layer attribute
|
||||
|
||||
Args:
|
||||
layer (:class:`torch.nn.Module`): The layer object
|
||||
layer_attr (str): The attribute name of the layer
|
||||
size (:class:`torch.Size`): The size of the tensor
|
||||
"""
|
||||
# Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features
|
||||
attrs = ["out_features", "in_features"]
|
||||
for i, attr in enumerate(attrs):
|
||||
if hasattr_(layer, f"{layer_attr}.{attr}"):
|
||||
setattr_(layer, f"{layer_attr}.{attr}", size[i])
|
||||
|
||||
def bind_layer(self, model: nn.Module) -> None:
|
||||
r"""
|
||||
Bind the layer according to the binding policy
|
||||
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): The shard model
|
||||
"""
|
||||
binding_map = self.policy.binding_policy()
|
||||
if binding_map is None:
|
||||
return
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(model, k)
|
||||
param = nn.Parameter(param)
|
||||
setattr_(model, k, param)
|
||||
setattr_(model, v, param)
|
||||
|
||||
|
||||
def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
|
||||
r"""
|
||||
The function is used to shard the PyTorch model.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): the original huggingface model
|
||||
shard_config (`ShardConfig`): the config for distribute information
|
||||
policy (`Policy`): the custom policy for sharding
|
||||
"""
|
||||
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
|
||||
sharder.shard()
|
||||
return model
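

# Hedged usage sketch (illustrative): sharding a HuggingFace model with the auto policy.
# It mirrors the test driver further below; rank and world_size come from the launcher.
if __name__ == "__main__":
    import os

    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    config = ShardConfig(rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    sharded = shard_model(model, config)    # policy=None -> use the auto policy for BertForMaskedLM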
|
@ -1,161 +0,0 @@
|
||||
import torch
|
||||
|
||||
from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
|
||||
from .shard_config import ShardConfig
|
||||
|
||||
dim_mapping = {Col_Layer: 1, Row_Layer: 0}
|
||||
|
||||
|
||||
class Slicer():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shardconfig: ShardConfig #TODO
|
||||
) -> None:
|
||||
self.shardconfig = shardconfig
|
||||
|
||||
def slice_weight_bias(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
policy_layer_cls: Layer,
|
||||
n_cast: int = None,
|
||||
reversed: bool = False,
|
||||
):
|
||||
r"""
|
||||
Slice the weight and bias according to policy layer cls
|
||||
``Layer`` -> do nothing
|
||||
``Col_Layer`` -> slice the weight and bias along dim 1
|
||||
``Row_Layer`` -> slice the weight along dim 0 and do not slice bias
|
||||
|
||||
Args:
|
||||
weight (:class:`torch.nn.Module`): The weight of the layer
|
||||
bias: (:class:`torch.nn.Module`): The bias of the layer
|
||||
policy_layer_class (:class:`Policy`): The class represent how to slice the tensor
|
||||
"""
|
||||
if policy_layer_cls == Layer:
|
||||
return weight, bias
|
||||
|
||||
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
|
||||
# print(weight.shape, dim)
|
||||
if policy_layer_cls == Col_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
bias = self.slice_tensor(bias, 0, True)
|
||||
elif policy_layer_cls == Row_Layer:
|
||||
weight = self.slice_tensor(weight, dim, False, n_cast)
|
||||
else:
|
||||
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
|
||||
if reversed:
|
||||
weight = weight.transpose(0, 1).contiguous()
|
||||
return weight, bias
|
||||
|
||||
def slice_tensor(
|
||||
self,
|
||||
tensor_in: torch.Tensor,
|
||||
dim: int,
|
||||
is_bias: bool,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice tensor according to the config
|
||||
|
||||
Args:
|
||||
tensor_in (:class:`torch.Tensor`): The tensor to slice
|
||||
dim (int): The dimension to slice
|
||||
is_bias (bool): Whether the tensor is bias
|
||||
"""
|
||||
if tensor_in is None:
|
||||
return None
|
||||
if not is_bias:
|
||||
return self.slice_2d(tensor_in, dim, n_cast)
|
||||
else:
|
||||
return self.slice_1d(tensor_in, n_cast)
|
||||
|
||||
def slice_2d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
dim: int,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 2D tensor
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
dim (int): The dimension to slice
|
||||
"""
|
||||
assert dim in [0, 1], f"Only 2D tensors are supported, but got dim {dim}"
|
||||
if dim == 0:
|
||||
return self.slice_row(tensor, n_cast)
|
||||
elif dim == 1:
|
||||
return self.slice_col(tensor, n_cast)
|
||||
|
||||
def slice_1d(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the 1D tensor
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
||||
|
||||
def slice_col(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in column
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=0).contiguous()
|
||||
|
||||
def slice_row(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
n_cast: int = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Slice the tensor in row
|
||||
|
||||
Args:
|
||||
tensor (:class:`torch.Tensor`): The tensor to slice
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
if n_cast is None:
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
|
||||
else:
|
||||
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
|
||||
chunk_list = [
|
||||
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
|
||||
]
|
||||
return torch.cat(chunk_list, dim=1).contiguous()
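

# Hedged toy example (illustrative, not part of the original file): how Col_Layer vs Row_Layer
# slicing behaves for rank 0 of a 2-way shard. A Linear weight is stored as [out_features, in_features].
if __name__ == "__main__":
    slicer = Slicer(ShardConfig(rank=0, world_size=2))
    weight, bias = torch.randn(4, 6), torch.randn(4)

    w_col, b_col = slicer.slice_weight_bias(weight, bias, Col_Layer)
    print(w_col.shape, b_col.shape)    # torch.Size([2, 6]) torch.Size([2]) -> output dim sharded

    w_row, b_row = slicer.slice_weight_bias(weight, bias, Row_Layer)
    print(w_row.shape, b_row.shape)    # torch.Size([4, 3]) torch.Size([4]) -> input dim sharded, bias kept whole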
|
@ -1 +0,0 @@
|
||||
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
|
@ -1,50 +0,0 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
|
||||
from colossalai.shardformer.layer.dropout import Dropout1D
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument("--module", type=str, default='distloss')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dist_crossentropy():
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True)
|
||||
labels = torch.randint(8, (1, 4)).repeat(2, 1)
|
||||
|
||||
pred_ = pred.view(-1, 8)
|
||||
labels_ = labels.view(-1)
|
||||
loss = F.cross_entropy(pred_, labels_)
|
||||
loss.backward()
|
||||
print(f"normal loss:{loss}")
|
||||
|
||||
pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
|
||||
loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
|
||||
loss.backward()
|
||||
print(f"dist loss:{loss}")
|
||||
|
||||
|
||||
def test_dropout():
|
||||
input = torch.randn(5, 4).to("cuda")
|
||||
m = Dropout1D(p=0.2).to("cuda")
|
||||
for i in range(2):
|
||||
print(f"Output: {m(input)}")
|
||||
print(torch.randn(1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
colossalai.launch_from_torch(config={})
|
||||
if args.module == 'distloss':
|
||||
test_dist_crossentropy()
|
||||
elif args.module == 'dropout':
|
||||
test_dropout()
|
||||
else:
|
||||
print("not implemented yet")
|
@ -1,124 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.shard import ShardConfig, shard_model
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument("--mode", type=str, default='inference')
|
||||
parser.add_argument("--save_model", action='store_true')
|
||||
parser.add_argument("--model", type=str, default='bert-base-uncased')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_data(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
# tokenizer.pad_token_id = 0
|
||||
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
|
||||
# datasets=load_dataset("yelp_review_full")
|
||||
tokenized_datasets = datasets.map(
|
||||
lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
|
||||
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
|
||||
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
|
||||
tokenized_datasets.set_format("torch")
|
||||
|
||||
train_dataset = tokenized_datasets["train"]
|
||||
test_dataset = tokenized_datasets["test"]
|
||||
|
||||
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
|
||||
eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def inference(model: nn.Module, args):
|
||||
print(model)
|
||||
# print(model.wte.weight.shape)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
tokenizer.pad_token_id = 0
|
||||
token = "Hello, my dog is cute"
|
||||
inputs = tokenizer(token, return_tensors="pt")
|
||||
inputs.to("cuda")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
outputs = model(**inputs)
|
||||
print(outputs[0])
|
||||
|
||||
|
||||
def train(model: nn.Module, args, num_epoch: int = 3):
|
||||
train_dataloader, eval_dataloader = load_data(args)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
num_training = num_epoch * len(train_dataloader)
|
||||
progress_bar = tqdm(range(num_training))
|
||||
lr_scheduler = get_scheduler(name="linear",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=num_training)
|
||||
best_test_loss = float("inf")
|
||||
model.to("cuda")
|
||||
model.train()
|
||||
for epoch in range(num_epoch):
|
||||
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
batch = {k: v.to('cuda') for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
progress_bar.update(1)
|
||||
train_loss = loss
|
||||
|
||||
loss = 0.0
|
||||
for batch in eval_dataloader:
|
||||
batch = {k: v.to('cuda') for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
# loss = outputs.loss
|
||||
assert not torch.isnan(outputs.loss), f"{batch}"
|
||||
loss += outputs.loss.item()
|
||||
# loss = criterion(outputs.logits, batch["input_ids"])
|
||||
test_loss = loss / len(eval_dataloader)
|
||||
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
|
||||
if args.save_model and test_loss < best_test_loss:
|
||||
best_test_loss = test_loss
|
||||
torch.save(model.state_dict(), "./checkpoints/best_model.pth")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
if args.model == 'bert-base-uncased':
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
elif args.model == 'gpt2':
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
else:
|
||||
raise AttributeError("model not supported")
|
||||
shard_config = ShardConfig(
|
||||
rank=int(str(get_current_device()).split(':')[-1]),
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
)
|
||||
sharded_model = shard_model(model, shard_config)
|
||||
|
||||
if args.mode == "train":
|
||||
train(sharded_model, args)
|
||||
elif args.mode == "inference":
|
||||
inference(sharded_model, args)
|
||||
else:
|
||||
raise NotImplementedError
|
@ -1,58 +0,0 @@
def hasattr_(obj, attr: str):
    r"""
    Check whether the object has the multi-level attribute

    Args:
        obj (object): The object to check
        attr (str): The multi-level attribute to check
    """
    attrs = attr.split('.')
    for a in attrs:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            return False
    return True


def setattr_(obj, attr: str, value, ignore: bool = False):
    r"""
    Set the object's multi-level attribute to value; if ignore is True, do nothing when the attribute doesn't exist

    Args:
        obj (object): The object to set
        attr (str): The multi-level attribute to set
        value (Any): The value to set
        ignore (bool): Whether to ignore when the attribute doesn't exist
    """

    attrs = attr.split('.')
    for a in attrs[:-1]:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            if ignore:
                return
            raise AttributeError(f"Object {obj} has no attribute {attr}")
    setattr(obj, attrs[-1], value)


def getattr_(obj, attr: str, ignore: bool = False):
    r"""
    Get the object's multi-level attribute

    Args:
        obj (object): The object to get from
        attr (str): The multi-level attribute to get
        ignore (bool): Whether to ignore when the attribute doesn't exist
    """

    attrs = attr.split('.')
    for a in attrs:
        try:
            obj = getattr(obj, a)
        except AttributeError:
            if ignore:
                return None
            raise AttributeError(f"Object {obj} has no attribute {attr}")
    return obj
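

# Hedged usage sketch (illustrative): these helpers address dotted attribute paths, which is how
# the sharder navigates nested HuggingFace modules (e.g. "bert.embeddings.word_embeddings.weight").
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4))
    print(hasattr_(model, "0.weight"))                         # True
    w = getattr_(model, "0.weight")                            # the Linear weight tensor
    setattr_(model, "0.weight", nn.Parameter(w * 2))
    print(getattr_(model, "0.bias.missing", ignore=True))      # None instead of raising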
|
@ -1,103 +0,0 @@
# 🔢 Distributed Tensor

## 📚 Table of Contents

- [🔢 Distributed Tensor](#-distributed-tensor)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [📝 Design](#-design)
  - [🔨 Usage](#-usage)
  - [🎈 Progress Log](#-progress-log)

## 🔗 Introduction

A distributed tensor is a tensor whose data is distributed across multiple devices. It is a wrapper around a PyTorch tensor, and it is used to support distributed training.
It can represent the device topology and the tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor.

## 📝 Design

Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses the notation `S` to represent a sharded dimension and `R` to represent a replicated dimension. For example, given a 2D matrix, `[S, R]` means that the tensor is sharded over the first dimension.

Each sharded dimension carries a subscript to represent its placement over the devices. Assume we have 4 GPUs arranged in a 2 x 2 mesh, and a 2D matrix like the one below:

```text
    [1,  2,  3,  4 ]
A = [4,  5,  6,  7 ]
    [8,  9,  10, 11]
    [12, 13, 14, 15]
```

`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology:

```text
|----------------------|----------------------|
|                      |                      |
|  [1, 2, 3, 4 ]       |  [1, 2, 3, 4 ]       |
|  [4, 5, 6, 7 ]       |  [4, 5, 6, 7 ]       |
|                      |                      |
|----------------------|----------------------|
|                      |                      |
|  [8,  9,  10, 11]    |  [8,  9,  10, 11]    |
|  [12, 13, 14, 15]    |  [12, 13, 14, 15]    |
|                      |                      |
|----------------------|----------------------|
```

`[S01, R]` would mean that the first dimension is sharded over both the rows and the columns in the device topology:

```text
|----------------------|----------------------|
|                      |                      |
|  [1, 2, 3, 4 ]       |  [4, 5, 6, 7 ]       |
|                      |                      |
|----------------------|----------------------|
|                      |                      |
|  [8, 9, 10, 11]      |  [12, 13, 14, 15]    |
|                      |                      |
|----------------------|----------------------|
```

## 🔨 Usage

A sample API usage is given below.

```python
import torch

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec

colossalai.launch_from_torch(config={})

# define your device mesh
# assume you have 4 GPUs
physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# define a tensor
a = torch.rand(16, 32).cuda()

# create a sharding spec for the tensor
# assume the sharding spec is [S0, R]
dim_partition_dict = {0: [0]}
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)

# create a distributed tensor
d_tensor = DTensor(a, device_mesh, sharding_spec)
print(d_tensor)

global_tensor = d_tensor.to_global()
print(global_tensor)
```
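
Following the same pattern, the `[S01, R]` layout shown in the design section would presumably be expressed by listing both mesh axes for dimension 0. This is a sketch based on the snippet above; the exact `dim_partition_dict` convention should be checked against the `ShardingSpec` documentation.

```python
# assumed continuation of the snippet above: shard dim 0 over both mesh axes, i.e. [S01, R]
dim_partition_dict = {0: [0, 1]}
sharding_spec_s01 = ShardingSpec(a.dim(), dim_partition_dict)
d_tensor_s01 = DTensor(a, device_mesh, sharding_spec_s01)
print(d_tensor_s01)
```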

## 🎈 Progress Log

- [x] Support layout conversion
- [x] Support sharding on 2D device mesh
- [ ] Support sharding on 3D device mesh
- [ ] Support sharding on 4D device mesh
- [ ] Support sharding info saving and offline tensor merge (we can save a tensor as a dtensor and gather the tensors back into the global tensor based on the sharding info in a single CPU process; useful for saving and loading checkpoints in distributed training)
@ -1,4 +0,0 @@
from .d_tensor import DTensor
from .sharding_spec import ShardingSpec

__all__ = ['DTensor', 'ShardingSpec']
@ -1,71 +0,0 @@
# Lazy initialization

Author: Hongxin Liu

**Prerequisite**
- [Booster API](../basics/booster_api.md)
- [Booster Plugins](../basics/booster_plugins.md)
- [Booster Checkpoint](../basics/booster_checkpoint.md)

**Related discussion**
- [Lazy initialization of model](https://github.com/hpcaitech/ColossalAI/discussions/3124)

## Introduction

LazyTensor allows the DL framework (PyTorch) to execute operations lazily, by storing all operations related to a tensor and rerunning them when the tensor is required to be materialized.

LazyInit defers model initialization and is based on LazyTensor.

This is especially useful when we use model parallelism to train large models, in which case the model cannot fit in GPU memory. With it, we can initialize model tensors as meta tensors and do static analysis to get the shard strategy, and then materialize each tensor and apply the shard strategy. The static analysis can be omitted if the shard strategy is known in advance.

## Usage

You may use lazy initialization when using Gemini, tensor parallelism, pipeline parallelism, and auto-parallelism. In other cases, you may not need to use lazy initialization.

Gemini is compatible with lazy initialization. You can use them together directly.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from torch.nn import Linear
import colossalai

colossalai.launch_from_torch({})

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

with LazyInitContext():
    model = Linear(10, 10)

optimizer = HybridAdam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```

Note that using lazy initialization with Gemini is not required but recommended. If you don't use lazy initialization, you may get an OOM error when initializing the model. If you use lazy initialization, you can avoid this error.

> ⚠ Lazy initialization support for tensor parallelism, pipeline parallelism, and auto-parallelism is still under development.

### Load from pretrained model

We should not load pretrained weights inside `LazyInitContext`. If we did, lazy initialization would be meaningless, as the checkpoint would be loaded and would take much GPU memory. The recommended way is to initialize the model from scratch inside `LazyInitContext` and load the pretrained weights outside `LazyInitContext` after calling `Booster.boost()`.

<!--- doc-test-ignore-start -->
```python
with LazyInitContext():
    model = GPT2LMHeadModel(config)

optimizer = ...
lr_scheduler = ...
dataloader = ...
model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)

booster.load_model(model, pretrained_path)
```
<!--- doc-test-ignore-end -->

As the booster supports both PyTorch-style checkpoints and HuggingFace/Transformers-style pretrained weights, the `pretrained_path` in the pseudo-code above can be either a checkpoint file path or a pretrained weight path. Note that it does not support loading pretrained weights over the network; you should download the pretrained weights first and then use a local path.

<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 lazy_init.py -->