mindformers.modules.transformer.MoEConfig
- class mindformers.modules.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2)
The configuration of MoE (Mixture of Experts).
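For orientation, the following minimal pure-Python sketch shows top-k gating, the routing step an MoE router performs: each token's router logits are turned into probabilities, and the `num_experts_chosen` highest-scoring of `expert_num` experts are selected. It assumes the common softmax-then-top-k formulation and is not mindformers' router implementation; the functions `softmax` and `route_token` are hypothetical names used only for illustration.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_token(logits: List[float], num_experts_chosen: int) -> List[Tuple[int, float]]:
    """Hypothetical top-k router: return (expert_index, probability) pairs.

    len(logits) plays the role of expert_num.
    """
    expert_num = len(logits)
    if not 1 <= num_experts_chosen <= expert_num:
        raise ValueError("num_experts_chosen must be in [1, expert_num]")
    probs = softmax(logits)
    # Rank expert indices by probability, highest first, and keep the top k.
    ranked = sorted(range(expert_num), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:num_experts_chosen]]


# One token, 4 experts (expert_num=4), routed to its top-2 experts.
print([i for i, _ in route_token([0.1, 2.0, -1.0, 1.5], num_experts_chosen=2)])  # [1, 3]
```

With `num_experts_chosen=1` this reduces to sending each token to its single highest-probability expert.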
- Args:
- expert_num (int): The number of experts employed. Default: 1.
- capacity_factor (float): The factor by which expert capacity is expanded; it should be >= 1.0. Default: 1.1.
- aux_loss_factor (float): The weight with which the load-balancing loss produced by the router is added to the overall model loss; it should be < 1.0. Default: 0.05.
- num_experts_chosen (int): The number of experts chosen by each token; it must not be larger than expert_num. Default: 1.
- expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter takes effect only in AUTO_PARALLEL mode, and not when SHARDING_PROPAGATION is used.
- group_wise_a2a (bool): Whether to enable group-wise all-to-all communication, which can reduce communication time by converting part of the inter-node communication into intra-node communication. Default: False. This parameter takes effect only when model_parallel > 1 and data_parallel equals expert_parallel.
- comp_comm_parallel (bool): Whether to run FFN computation and communication in parallel, which can reduce the time spent purely on communication by splitting computation and communication into chunks and overlapping them. Default: False.
- comp_comm_parallel_degree (int): The number of chunks into which computation and communication are split. A larger value allows more overlap but consumes more memory. Default: 2. This parameter takes effect only when comp_comm_parallel is True.
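To make `capacity_factor`, `num_experts_chosen`, and `aux_loss_factor` more concrete, the sketch below uses two widely used formulations that are assumed here rather than taken from mindformers: expert capacity as ceil(num_experts_chosen * tokens_per_group * capacity_factor / expert_num), and a Switch-Transformer-style load-balancing loss expert_num * sum_i(f_i * P_i) scaled by `aux_loss_factor`, where f_i is the fraction of tokens dispatched to expert i and P_i is the mean router probability for expert i. The functions `expert_capacity` and `load_balance_loss` are hypothetical.

```python
import math
from typing import List


def expert_capacity(tokens_per_group: int, expert_num: int,
                    num_experts_chosen: int = 1, capacity_factor: float = 1.1) -> int:
    # Assumed formula: each token occupies num_experts_chosen expert slots, spread
    # evenly over experts, padded by capacity_factor and rounded up to whole slots.
    return math.ceil(num_experts_chosen * tokens_per_group * capacity_factor / expert_num)


def load_balance_loss(assignments: List[int], router_probs: List[List[float]],
                      aux_loss_factor: float = 0.05) -> float:
    """Switch-Transformer-style auxiliary loss (assumed formulation).

    assignments[t]: index of the expert that token t was dispatched to.
    router_probs[t]: router probability distribution of token t over the experts.
    """
    num_tokens = len(assignments)
    expert_num = len(router_probs[0])
    # f_i: fraction of tokens dispatched to expert i.
    f = [assignments.count(i) / num_tokens for i in range(expert_num)]
    # P_i: mean router probability assigned to expert i.
    p = [sum(probs[i] for probs in router_probs) / num_tokens for i in range(expert_num)]
    return aux_loss_factor * expert_num * sum(fi * pi for fi, pi in zip(f, p))


print(expert_capacity(tokens_per_group=64, expert_num=4, capacity_factor=5.0))  # 80
# Balanced routing: 4 tokens, one per expert, uniform router probabilities.
print(load_balance_loss([0, 1, 2, 3], [[0.25] * 4 for _ in range(4)]))  # 0.05
# Collapsed routing: every token goes to expert 0 with probability 1.
print(load_balance_loss([0, 0, 0, 0], [[1.0, 0.0, 0.0, 0.0]] * 4))  # 0.2
```

Under this formulation, perfectly balanced routing yields exactly aux_loss_factor, and routing that collapses onto one expert is penalized up to expert_num times more, which is what pushes the router toward spreading tokens.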
- Supported Platforms:
Ascend GPU
- Examples:
>>> from mindformers.modules.transformer import MoEConfig
>>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
...                        comp_comm_parallel_degree=2)
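The constraints stated in the Args list (capacity_factor >= 1.0, aux_loss_factor < 1.0, num_experts_chosen not larger than expert_num) can be captured in a small standalone check. The `MoESettings` dataclass below is a hypothetical stand-in written for illustration: it does not import mindformers and is not claimed to reproduce the validation that MoEConfig itself performs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoESettings:
    """Hypothetical mirror of the MoEConfig arguments and their documented defaults."""
    expert_num: int = 1
    capacity_factor: float = 1.1
    aux_loss_factor: float = 0.05
    num_experts_chosen: int = 1
    expert_group_size: Optional[int] = None
    group_wise_a2a: bool = False
    comp_comm_parallel: bool = False
    comp_comm_parallel_degree: int = 2

    def __post_init__(self) -> None:
        # Each check corresponds to a constraint stated in the Args list.
        if self.capacity_factor < 1.0:
            raise ValueError("capacity_factor should be >= 1.0")
        if not self.aux_loss_factor < 1.0:
            raise ValueError("aux_loss_factor should be < 1.0")
        if self.num_experts_chosen > self.expert_num:
            raise ValueError("num_experts_chosen should not be larger than expert_num")


# The argument values used in the MoEConfig example pass these checks.
settings = MoESettings(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05,
                       num_experts_chosen=1, expert_group_size=64, group_wise_a2a=True,
                       comp_comm_parallel=False, comp_comm_parallel_degree=2)
```

Keeping the checks in `__post_init__` means an invalid combination fails at construction time rather than later inside the MoE layer.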