**Shardformer** is a module that automatically parallelizes mainstream models from libraries such as HuggingFace and TIMM. It aims to make parallelization hassle-free for users who do not have a systems background.
If you want to parallelize a model in a custom way, just override the policy class for the Hugging Face model. Please refer to any of the pre-established policies, such as [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py).
- In this method, you need to list which layer classes you want to modify, along with the attributes and parameters of those layers. Shardformer will replace every layer belonging to the classes you specify.
- `attr_dict` is a dict containing all the attributes that need to be modified in this layer.
- `param_funcs` is a list of functions that return the paths of the weight and bias within the layer.
3. Override the `inject_policy` method (optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward pass (like the distributed cross-entropy loss in BERT), you need to override this method.
- Each function in `param_funcs` returns a list of `Col_Layer`, `Row_Layer`, or `Dropout_Layer` objects, which slice the weight along the column or row dimension or handle a dropout layer, respectively; refer to the `Layer` class for more details. A sketch of a full custom policy follows this list.
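As an illustration, a custom policy might look like the minimal sketch below. The base-class signature, the `Argument` container, and the import path are assumptions based on the descriptions above, not the exact Shardformer API; consult the pre-established policies such as [bert policy](./policies/bert.py) for the real signatures.

```python
# Minimal sketch of a custom policy; names and signatures are assumed
# for illustration -- see ./policies/bert.py for the real pattern.
from transformers.models.bert.modeling_bert import BertLayer

from .basepolicy import Policy, Argument, Col_Layer, Row_Layer  # assumed path


class MyBertPolicy(Policy):
    @staticmethod
    def argument_policy(config, world_size):
        # Map each layer class to its modified attributes (attr_dict)
        # and the functions locating its weights/biases (param_funcs).
        return {
            BertLayer: Argument(
                attr_dict={
                    # "suffix-string: value" pairs, e.g. shrink the head
                    # count to the per-device share
                    "attention.self.num_attention_heads":
                        config.num_attention_heads // world_size,
                },
                param_funcs=[MyBertPolicy.attn_in, MyBertPolicy.attn_out],
            ),
        }

    @staticmethod
    def attn_in():
        # Column-sliced input projection; the suffixes record the paths
        # of the weight and bias inside the layer.
        return [
            Col_Layer(weight="attention.self.query.weight",
                      bias="attention.self.query.bias"),
        ]

    @staticmethod
    def attn_out():
        # Row-sliced output projection.
        return [
            Row_Layer(weight="attention.output.dense.weight",
                      bias="attention.output.dense.bias"),
        ]
```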
This is the user API for Shardformer: just create a model from transformers and define a custom policy, or use Shardformer's auto policy, to build a sharded model.
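A hypothetical end-to-end usage might look like this (the import path is an assumption; `ShardModel` is described below):

```python
# Hypothetical usage of the user API; the import path is assumed.
from transformers import BertForMaskedLM

from shardformer.shard import ShardModel  # assumed import path

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Pass a custom policy, or None to let Shardformer pick the
# pre-established policy (auto policy) for this model.
sharded_model = ShardModel(model, policy=None)
```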
- reversed (bool): Whether the weight of the layer is stored reversed. The weight of `torch.nn.Linear` is commonly [out, in], but the GPT2 `Conv1D` layer stores it as [in, out], which is reversed.
- n_cast (int): The number of parts the weight is fused from; for the fused q, k, v weight in an attention layer, n_cast should be 3. Commonly in TP we just chunk the weight by the number of devices, but in multi-head attention we need to chunk the weight into $devices \times n\_cast$ parts so that each device receives a part of the Q, K, and V weights.
- gather_output (bool): Whether the output of this layer needs to be gathered, as for the last layer; most of the time, the intermediate layers of the model do not need to be gathered.
This class inherits from `Layer` and represents a layer that will be sliced along the column dimension; it records the attributes of the weight and bias. Setting `bias` to `None` means the bias is ignored, regardless of whether it originally exists.
This class inherits from `Layer` and represents a layer that will be sliced along the row dimension. It is just like `Col_Layer`, but in tensor parallelism there is no need to gather the output of a layer sliced by row. Both are illustrated in the sketch below.
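For illustration, the declarations below show how these fields might be set; the suffix strings refer to GPT-2 modules, and the constructor signature is an assumption mirroring the attributes described above:

```python
# Illustrative Layer declarations; suffixes point at GPT-2 modules.
qkv = Col_Layer(
    weight="attn.c_attn.weight",
    bias="attn.c_attn.bias",
    n_cast=3,        # fused q, k, v weight: split into devices * 3 chunks
    reversed=True,   # GPT-2 Conv1D stores the weight as [in, out]
)

out_proj = Row_Layer(
    weight="attn.c_proj.weight",
    bias=None,       # ignore the bias even if it exists
    reversed=True,
)

lm_head = Col_Layer(
    weight="lm_head.weight",
    gather_output=True,   # gather the sharded outputs of the last layer
)
```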
In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of `Layer` objects that specify the replacement approach for those parameters. In addition to the listed functions, Shardformer also supports user-defined functions for modifying the model.
In this function, the user should use a dict to define which layer classes require replacement, including the attributes and parameters that need to be modified or replaced. Attributes are stored as a "suffix-string: value" dict, while parameters are handled via static methods that return the replacement approach.
This function returns the injected model class that replaces the original one. The new model should be an `nn.Module` class with modified forward or backward functions, or any other changes needed.
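As a rough sketch, and assuming `inject_policy` pairs the original model class with its replacement (the exact return convention is an assumption; see the bert policy for the real one):

```python
# Sketch of injecting a model with a modified forward; the pairing
# convention (original class, replacement class) is an assumption.
from transformers.models.bert.modeling_bert import BertForMaskedLM

from .basepolicy import Policy  # assumed import path


class BertForMaskedLM_(BertForMaskedLM):
    def forward(self, *args, **kwargs):
        # e.g. compute the distributed cross-entropy loss described
        # below instead of the default loss
        ...


class MyBertPolicy(Policy):
    @staticmethod
    def inject_policy():
        return (BertForMaskedLM, BertForMaskedLM_)
```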
This class helps shard the model. Its parameters are the created transformers model and the custom policy; if the custom policy is `None`, Shardformer will automatically use the pre-defined policy for the model.
In order to reduce the communication size, we calculate the cross-entropy before the all-gather. Following [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), this reduces the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The original loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
which can also be represented as:
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
- Steps (see the code sketch after this list)
  - First get the maximum logit across all devices, and subtract it from every logit to scale the values to at most zero, so that the value of exp does not become too large
  - Get a mask to mask out the logits that are not on the local device
  - Calculate the loss according to the second formula
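Below is a minimal PyTorch sketch of these steps, assuming the vocab dimension of `logits` is already sharded evenly across the tensor-parallel group. The function name and the `vocab_start`/`vocab_end` arguments are illustrative, and only the forward computation is shown; a real implementation also needs a custom autograd function for the backward pass.

```python
# Sketch of vocab-parallel cross-entropy in the style of Megatron-LM.
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(logits: torch.Tensor,
                                 target: torch.Tensor,
                                 vocab_start: int,
                                 vocab_end: int) -> torch.Tensor:
    # logits: [batch_size * seq_length, vocab_size / world_size]
    # target: [batch_size * seq_length], global vocab indices

    # Step 1: subtract the global max logit so exp cannot overflow.
    logits_max = logits.max(dim=-1)[0]
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
    logits = logits - logits_max.unsqueeze(-1)

    # Step 2: mask out targets not on this shard, then pick the
    # predicted logit x[class] from the local partition. Each position's
    # class logit lives on exactly one device; the others contribute 0.
    target_mask = (target >= vocab_start) & (target < vocab_end)
    local_target = (target - vocab_start) * target_mask
    predicted = logits.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
    predicted = predicted * target_mask
    dist.all_reduce(predicted, op=dist.ReduceOp.SUM)

    # Step 3: loss = log(sum_i exp(x[i])) - x[class], summing the
    # partial exp sums across devices instead of gathering the logits.
    sum_exp = logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM)
    loss = torch.log(sum_exp) - predicted
    return loss.mean()
```

Note that the only communicated tensors are [batch_size * seq_length]-shaped (the max, the predicted logit, and the partial exp sum), which is exactly the communication saving described above.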