mindformers.modules.transformer.MoEConfig

class mindformers.modules.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2)

The configuration of MoE (Mixture of Experts).

Parameters
  • expert_num (int) – The number of experts to use. Default: 1.

  • capacity_factor (float) – The factor by which each expert's capacity is expanded. Must be >= 1.0. Default: 1.1.

  • aux_loss_factor (float) – The weight applied to the load-balancing loss produced by the router before it is added to the overall model loss. Must be < 1.0. Default: 0.05.

  • num_experts_chosen (int) – The number of experts chosen by each token. Must not be larger than expert_num. Default: 1.

  • expert_group_size (int) – The number of tokens in each data parallel group. Default: None. This parameter takes effect only in AUTO_PARALLEL mode, and only when the search mode is not SHARDING_PROPAGATION.

  • group_wise_a2a (bool) – Whether to enable group-wise all-to-all communication, which can reduce communication time by converting part of the inter-node communication into intra-node communication. Default: False. This parameter takes effect only when model parallel > 1 and data_parallel is equal to expert parallel.

  • comp_comm_parallel (bool) – Whether to run FFN computation and communication in parallel, which can reduce pure communication time by splitting computation and communication into chunks and overlapping them. Default: False.

  • comp_comm_parallel_degree (int) – The number of chunks into which computation and communication are split. A larger value gives more overlap but consumes more memory. Default: 2. This parameter takes effect only when comp_comm_parallel is enabled.
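The interaction between expert_num, num_experts_chosen and capacity_factor can be illustrated with a minimal pure-Python sketch. It assumes the Switch Transformer/GShard-style capacity formula, capacity = ceil(k × tokens_per_group × capacity_factor / expert_num), together with a simplified greedy top-k dispatch that drops an assignment once the chosen expert is full. The helper names expert_capacity and route_top_k are hypothetical and not part of the MindFormers API, and the library's internal routing may differ in detail (for example, in the order in which first and second choices are served).

```python
import math
from typing import List


def expert_capacity(tokens_per_group: int, expert_num: int,
                    num_experts_chosen: int, capacity_factor: float) -> int:
    """Hypothetical helper: token slots per expert (Switch/GShard-style formula, assumed)."""
    # Even share of the k * tokens routing decisions, scaled up by capacity_factor.
    return math.ceil(num_experts_chosen * tokens_per_group * capacity_factor / expert_num)


def route_top_k(scores: List[List[float]], num_experts_chosen: int,
                capacity: int) -> List[List[int]]:
    """Hypothetical greedy dispatch: send each token to its top-k experts,
    skipping any expert that has already reached `capacity` tokens.

    scores[t][e] is the router score of token t for expert e.
    Returns, for each token, the experts it was actually dispatched to.
    """
    expert_num = len(scores[0])
    load = [0] * expert_num  # tokens already assigned to each expert
    dispatched = []
    for token_scores in scores:
        # Rank experts by descending score and keep the k best.
        ranked = sorted(range(expert_num), key=lambda e: token_scores[e], reverse=True)
        kept = []
        for e in ranked[:num_experts_chosen]:
            if load[e] < capacity:  # expert still has a free slot
                load[e] += 1
                kept.append(e)
        dispatched.append(kept)  # an empty list means the token was dropped
    return dispatched


# 8 tokens, 4 experts, top-1 routing, default capacity_factor=1.1.
cap = expert_capacity(tokens_per_group=8, expert_num=4, num_experts_chosen=1, capacity_factor=1.1)
print(cap)  # prints 3
# Every token prefers expert 0, so only the first `cap` tokens are kept.
print(route_top_k([[0.7, 0.1, 0.1, 0.1]] * 8, num_experts_chosen=1, capacity=cap))
```

With the default capacity_factor=1.1, 8 tokens and 4 experts under top-1 routing give ceil(2.2) = 3 slots per expert instead of an even 2, leaving some headroom for uneven routing. A larger factor drops fewer tokens but enlarges each expert's buffer.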
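aux_loss_factor scales an auxiliary load-balancing loss before it joins the model loss. The sketch below assumes the Switch Transformer formulation, N · Σ f_i · P_i, where N is the number of experts, f_i is the fraction of tokens routed to expert i, and P_i is the mean router probability for expert i. This is an illustrative assumption, not a statement of exactly how the MindFormers router computes its loss; load_balance_loss and total_loss are hypothetical names.

```python
from typing import List


def load_balance_loss(router_probs: List[List[float]], assignments: List[int]) -> float:
    """Hypothetical Switch-Transformer-style auxiliary loss: N * sum_i f_i * P_i.

    router_probs[t][i]: softmax probability of token t for expert i.
    assignments[t]: index of the expert token t was routed to (top-1).
    """
    num_tokens = len(router_probs)
    expert_num = len(router_probs[0])
    # f_i: fraction of tokens dispatched to expert i.
    f = [assignments.count(i) / num_tokens for i in range(expert_num)]
    # P_i: mean router probability assigned to expert i.
    p = [sum(row[i] for row in router_probs) / num_tokens for i in range(expert_num)]
    return expert_num * sum(fi * pi for fi, pi in zip(f, p))


def total_loss(task_loss: float, aux_loss: float, aux_loss_factor: float = 0.05) -> float:
    """Hypothetical: aux_loss_factor controls how much balance loss enters the model loss."""
    return task_loss + aux_loss_factor * aux_loss


balanced = load_balance_loss([[0.5, 0.5]] * 4, [0, 1, 0, 1])
collapsed = load_balance_loss([[1.0, 0.0]] * 4, [0, 0, 0, 0])
print(balanced, collapsed)  # prints 1.0 2.0
print(total_loss(task_loss=2.0, aux_loss=collapsed))
```

Uniform routing over two experts gives 1.0, while sending every token to one expert gives 2.0, so the loss penalizes collapse onto a few experts. With aux_loss_factor=0.05 only 5% of that value is added to the task loss, which keeps balancing from dominating the training objective.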

Supported Platforms:

Ascend GPU

Examples

>>> from mindformers.modules.transformer import MoEConfig
>>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
...                        comp_comm_parallel_degree=2)
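Before constructing a MoEConfig, the constraints stated in the parameter list can be checked up front. The helpers below are hypothetical sketches that mirror only the documented rules (capacity_factor >= 1.0, aux_loss_factor < 1.0, num_experts_chosen <= expert_num, and the conditions under which group_wise_a2a takes effect). They do not call MindFormers and do not reproduce whatever validation the library itself performs.

```python
from typing import List


def check_moe_args(expert_num: int = 1, capacity_factor: float = 1.1,
                   aux_loss_factor: float = 0.05, num_experts_chosen: int = 1) -> List[str]:
    """Hypothetical check of the documented MoEConfig constraints.

    Returns human-readable problems; an empty list means the values
    satisfy the documented rules.
    """
    problems = []
    if capacity_factor < 1.0:
        problems.append(f"capacity_factor must be >= 1.0, got {capacity_factor}")
    if aux_loss_factor >= 1.0:
        problems.append(f"aux_loss_factor must be < 1.0, got {aux_loss_factor}")
    if num_experts_chosen > expert_num:
        problems.append(
            f"num_experts_chosen ({num_experts_chosen}) must not exceed expert_num ({expert_num})")
    return problems


def group_wise_a2a_applies(model_parallel: int, data_parallel: int,
                           expert_parallel: int) -> bool:
    """Hypothetical: group_wise_a2a only takes effect when model parallel > 1
    and data_parallel equals expert parallel, per the parameter description."""
    return model_parallel > 1 and data_parallel == expert_parallel


# The values used in the example above satisfy the documented constraints.
print(check_moe_args(expert_num=4, capacity_factor=5.0,
                     aux_loss_factor=0.05, num_experts_chosen=1))  # prints []
```

Returning a list of problems rather than raising on the first one lets a launcher report every misconfigured value at once.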