
Sparse Mixture of Experts, Dropout, and Model Robustness

Published on: 2-27-2025

Brief overview of SMoE and the family of robustness techniques

ML models face tradeoffs between complexity and simplicity along several different dimensions.
Adversarial attack robustness. Adversarial attacks are most often framed as white-box attacks, in which the attacker has full access to the internals of a model. That access lets the attacker compute gradients directly to find adversarial inputs, creating attacks that are nearly indistinguishable from regular inputs.
Overfitting during training. Models that are too complex have more parameters than their dataset can support and overfit to the training data more easily. Stopping training earlier can prevent overfitting and lower generalization error, but this is inefficient; it would be better to fully train a simpler model instead. Overfitting leads to poor generalization and poor out-of-distribution performance.
Out-of-distribution performance. This ties in closely with overfitting. The general risk decomposition of any machine learning model can be described as:
Expected risk = approximation error + generalization error + training error
The approximation error describes the gap between the best model in your hypothesis class and the Bayes-optimal predictor f*. This is the gap between the perfect version of your model and the best possible predictor of any kind. The generalization error describes the problem caused by limited training data. It manifests in our models as overfitting, where the model discovers spurious correlations in our finite sample of the true data distribution and reaches false conclusions. The training error describes the gap between the best version of your model and the model your training procedure actually finds.
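To make the decomposition concrete, here is one standard way of writing it out. The notation is mine rather than the post's: f-hat is the model training actually produces, f_H is the best model in the hypothesis class H, f-hat_n is the best model reachable from the finite training sample, and R(.) denotes expected risk.

```latex
% Excess risk of the trained model \hat{f} relative to the Bayes-optimal predictor f^*.
% f_H: best model in the hypothesis class H; \hat{f}_n: best model reachable from
% the finite training sample of size n.
\[
R(\hat{f}) - R(f^*) =
    \underbrace{\bigl(R(f_{\mathcal{H}}) - R(f^*)\bigr)}_{\text{approximation error}}
  + \underbrace{\bigl(R(\hat{f}_n) - R(f_{\mathcal{H}})\bigr)}_{\text{generalization error}}
  + \underbrace{\bigl(R(\hat{f}) - R(\hat{f}_n)\bigr)}_{\text{training error}}
\]
```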
More complex models reduce approximation error. Larger training datasets reduce generalization error. Longer training reduces training error. These are the three pillars of creating an accurate model. Training for too long causes overfitting, which lowers training error at the cost of generalization error. If your model is overfitting during training without achieving the desired performance, you need more data. If adding more data starts giving diminishing returns, you need a more complex model. If overfitting sets in too early, you need a simpler model.
There are several techniques for balancing these competing interests. In this article, we review a few prominent ones.
At the input level, we have data smoothing. Data smoothing is a collection of techniques that normalize data for improved training stability and inference quality. The most common example is batch normalization, which standardizes a layer's activations to zero mean and unit variance across each mini-batch, followed by a learned scale and shift. The technique is agnostic to input and network dimensions: CNNs, RNNs, LSTMs, MLPs, essentially any model architecture can use some form of it. This contrasts with other techniques like exponential smoothing and Gaussian smoothing, which are applied to time-series and image data respectively. Those techniques operate directly on the data, dampening its signal to reduce the effectiveness of an adversarial attack and to minimize the impact of edge cases.
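As a minimal sketch of the batch-normalization case (generic PyTorch, not code from this post), a BatchNorm layer standardizes each feature over the current mini-batch before it reaches the next layer:

```python
import torch
import torch.nn as nn

# A small MLP with batch normalization between layers.
# BatchNorm1d standardizes each of the 64 features across the mini-batch
# (zero mean, unit variance), then applies a learned scale and shift.
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),   # normalizes activations using mini-batch statistics
    nn.ReLU(),
    nn.Linear(64, 10),
)

x = torch.randn(128, 32)  # a mini-batch of 128 examples
logits = model(x)         # batch statistics are computed on the fly in training mode
```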
At the training level, we have adversarial training based on the fast gradient sign method (FGSM). FGSM itself is an attack: it perturbs an input in the direction of the sign of the loss gradient. Used as a defense, adversarial training adds an adversarial-weakness term to the loss function, so the model works to minimize the maximum performance drop a minor input perturbation could cause.
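A rough sketch of what that looks like in practice (generic PyTorch; the 50/50 mix of clean and adversarial loss and the epsilon value are illustrative assumptions, not a prescribed recipe):

```python
import torch

def fgsm_adversarial_step(model, loss_fn, optimizer, x, y, epsilon=0.03):
    """One training step that mixes clean and FGSM-perturbed examples."""
    # 1. Compute the gradient of the loss with respect to the input.
    x_adv = x.clone().detach().requires_grad_(True)
    loss_clean = loss_fn(model(x_adv), y)
    grad_x, = torch.autograd.grad(loss_clean, x_adv)

    # 2. FGSM: step in the direction of the sign of the input gradient.
    x_perturbed = (x_adv + epsilon * grad_x.sign()).detach()

    # 3. Train on a combination of the clean and adversarial losses.
    optimizer.zero_grad()
    loss = 0.5 * loss_fn(model(x), y) + 0.5 * loss_fn(model(x_perturbed), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```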
At the neuron level, we have dropout and smoother activation functions. Dropout randomly drops individual neurons during training, which builds redundancy into the model: multiple neurons learn to build the same intermediate representation instead of relying on a narrow combination of activations. This empirically improves training stability and robustness. Another technique for improving robustness is using smoother activation functions. The most common activation function, ReLU, has a sharp kink at the origin. This leads to non-smooth loss landscapes that make the model less stable during training and more prone to adversarial attack.

At the individual network level, we have paradigms like sparse mixture of experts. Mixture of experts is a network architecture in which blocks of a neural network are strategically sectioned into "expert blocks" that can be selectively activated for each token or input by a routing network. Sparse mixture of experts takes this a step further: the routing network selects only the top-k expert blocks and effectively "drops out" the rest of the experts in a layer, as in the sketch below. This minimizes training and inference cost by skipping the experts that contribute least.
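A simplified sketch of top-k routing (generic PyTorch; real SMoE layers add load-balancing losses and expert capacity limits, which are omitted here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoELayer(nn.Module):
    """Simplified sparse mixture-of-experts layer with top-k routing."""
    def __init__(self, dim, num_experts=8, k=2):
        super().__init__()
        self.k = k
        self.router = nn.Linear(dim, num_experts)  # routing network
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x):                                    # x: (tokens, dim)
        scores = self.router(x)                              # (tokens, num_experts)
        topk_scores, topk_idx = scores.topk(self.k, dim=-1)  # keep only the top-k experts
        weights = F.softmax(topk_scores, dim=-1)             # renormalize over the chosen experts

        out = torch.zeros_like(x)
        for slot in range(self.k):
            for e, expert in enumerate(self.experts):
                mask = topk_idx[:, slot] == e                 # tokens routed to expert e in this slot
                if mask.any():
                    out[mask] += weights[mask, slot:slot + 1] * expert(x[mask])
        return out                                            # non-selected experts are never run
```

Only the selected experts are evaluated for each token, which is where the compute savings come from.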
At the network level, we have ensemble techniques. To create an ensemble, we combine several models and have them "vote" on the expected result. Based on the capability of each model, we can additionally weight each model's vote to emphasize the more powerful models. The greater the diversity of models, the more robustness naturally forms. Adversarial attacks, by design, work best against individual networks: two networks, even if they are identical in architecture and hyperparameters, can behave very differently depending on their initialization. Ensembles therefore vastly increase the complexity of the loss landscape, making an adversarial attack much more difficult.
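A minimal sketch of a weighted soft vote over an ensemble (generic PyTorch; the models and weights are assumed placeholders):

```python
import torch

def weighted_ensemble_predict(models, weights, x):
    """Combine predictions from several models with a weighted soft vote."""
    assert len(models) == len(weights)
    total = sum(weights)
    avg_probs = None
    for model, w in zip(models, weights):
        model.eval()
        with torch.no_grad():
            probs = torch.softmax(model(x), dim=-1)  # each model votes with its class probabilities
        contribution = (w / total) * probs           # stronger models get a larger say
        avg_probs = contribution if avg_probs is None else avg_probs + contribution
    return avg_probs.argmax(dim=-1)                  # final "vote": highest weighted probability
```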
These are the levels, from the input all the way up to ensembles of networks, at which robustness can be built into a model.