To demonstrate the use of MaxWEnt with TensorFlow, we consider a two-dimensional binary classification problem: the input space $\mathcal{X}$ lies in $\mathbb{R}^2$, and the output (label) space is composed of two classes, $\mathcal{Y} = \{0, 1\}$.
The training instances are drawn from scikit-learn's two-moons classification problem, so scikit-learn must be installed to run this tutorial.
import sys
sys.path.append("../../")
Setup
import numpy as np
import matplotlib.pyplot as plt
from maxwent import classification_2d, plot_classification_2d
Below is an illustration of the problem. The training data are concentrated in the center of the input space, leaving a significant portion of $\mathcal{X}$ outside the training distribution. In other words, a large region of the input space is not covered by the training data.
x_train, y_train, x_ood = classification_2d()
ax = plot_classification_2d(x_train, y_train, x_ood)
ax.legend(loc="upper right"); plt.show()
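The helper classification_2d hides how the data are built. The following NumPy sketch produces data of the same kind without the maxwent package. It reproduces the half-circle construction used by scikit-learn's make_moons, with the outer moon as class 0 and the inner moon as class 1. The noise level, the sample count, and the bounds of the out-of-distribution grid are our own guesses, not the values used by classification_2d. The names make_two_moons and make_grid are hypothetical, and the outputs use new variable names so that they do not overwrite the tutorial's data.

```python
import numpy as np

def make_two_moons(n_samples=200, noise=0.1, seed=0):
    """Two interleaving half-circles, following the construction of
    scikit-learn's make_moons (outer moon = class 0, inner moon = class 1)."""
    rng = np.random.default_rng(seed)
    n_out = n_samples // 2
    n_in = n_samples - n_out
    t_out = np.linspace(0.0, np.pi, n_out)
    t_in = np.linspace(0.0, np.pi, n_in)
    outer = np.column_stack([np.cos(t_out), np.sin(t_out)])
    inner = np.column_stack([1.0 - np.cos(t_in), 1.0 - np.sin(t_in) - 0.5])
    x = np.vstack([outer, inner])
    y = np.concatenate([np.zeros(n_out), np.ones(n_in)])
    x += rng.normal(scale=noise, size=x.shape)  # Gaussian jitter around the arcs
    return x, y

def make_grid(low=-2.0, high=3.0, n=50):
    """Regular n x n grid covering a box larger than the training data,
    used here as out-of-distribution query points."""
    xx, yy = np.meshgrid(np.linspace(low, high, n), np.linspace(low, high, n))
    return np.column_stack([xx.ravel(), yy.ravel()])

x_toy, y_toy = make_two_moons()
x_grid = make_grid()
```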
Base Network
Next, we define the neural network architecture. It is a fully connected network with two hidden dense layers of 100 units, each followed by a ReLU activation. Since this is a binary classification task, the output layer is a single unit with a sigmoid activation.
import tensorflow as tf
base_net = tf.keras.Sequential()
base_net.add(tf.keras.layers.Input(shape=(2,)))
base_net.add(tf.keras.layers.Dense(100))
base_net.add(tf.keras.layers.ReLU())
base_net.add(tf.keras.layers.Dense(100))
base_net.add(tf.keras.layers.ReLU())
base_net.add(tf.keras.layers.Dense(1, activation="sigmoid"))
base_net.summary()
Model: "sequential"
Layer (type)                  Output Shape      Param #
dense (Dense)                 (None, 100)           300
re_lu (ReLU)                  (None, 100)             0
dense_1 (Dense)               (None, 100)        10,100
re_lu_1 (ReLU)                (None, 100)             0
dense_2 (Dense)               (None, 1)             101
Total params: 10,501 (41.02 KB)
Trainable params: 10,501 (41.02 KB)
Non-trainable params: 0 (0.00 B)
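As a sanity check, the parameter counts in the summary follow directly from the layer sizes. A Dense layer that maps n_in features to n_out features has n_in × n_out kernel weights plus n_out biases, and ReLU layers have no parameters. The reported 41.02 KB matches 4 bytes (float32) per parameter, divided by 1024. The helper name dense_params is ours.

```python
def dense_params(n_in, n_out):
    """A Dense layer stores an (n_in, n_out) kernel plus an n_out bias vector."""
    return n_in * n_out + n_out

layer_sizes = [2, 100, 100, 1]  # input, two hidden layers, output (ReLU adds 0)
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
print(counts)                  # [300, 10100, 101]
print(sum(counts))             # 10501
print(sum(counts) * 4 / 1024)  # 41.01953125 (KB, float32)
```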
We specify the loss function and optimizer for training: the Adam optimizer with a learning rate of 0.001 and the binary cross-entropy (BCE) loss, which is the standard choice for binary classification. The network is then trained for 250 epochs.
base_net.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss="bce")
base_net.fit(x_train, y_train, epochs=250, verbose=0);
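For reference, the "bce" loss is the binary cross-entropy $-\frac{1}{n}\sum_k \left[y_k \log p_k + (1-y_k)\log(1-p_k)\right]$, where $p_k$ is the sigmoid output. Below is a NumPy sketch of this formula. Clipping the probabilities keeps the logarithms finite. Keras also guards against $\log 0$, but the epsilon value used here is our own choice, and the function name is hypothetical.

```python
import numpy as np

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities.
    Probabilities are clipped away from 0 and 1 so the logarithms stay finite."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))

y = np.array([0, 1, 1, 0])
print(binary_cross_entropy(y, [0.1, 0.9, 0.8, 0.3]))
```

A prediction of 0.5 costs $\log 2$ per sample whatever the label. Confident correct predictions cost almost nothing, and confident wrong ones are penalized heavily.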
Now that the network has been trained, we visualize its predictions across the entire input space $\mathcal{X}$. In the figure below, the predicted probability of class 1 is shown with the "seismic" colormap, which runs from blue (close to 0) through white (0.5) to red (close to 1).
y_pred = base_net.predict(x_ood, verbose=0)
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=y_pred, cmap="seismic")
plt.colorbar(unc, ax=ax, label='Predicted probability of class 1')
ax.legend(loc="upper left"); plt.show()
As we can see, the model divides the input domain into two distinct regions, each assigned to one of the classes.
However, this decision relies on a strong extrapolation: it assumes that each class naturally extends into its respective region, with class 0 occupying the upper part and class 1 the lower. Other explanations are equally consistent with the training data. The true decision boundary could follow a more complex pattern, such as a twisted spiral, which would contradict the model's assumptions. Out-of-distribution data might also belong to neither class 0 nor class 1, but to entirely new, unseen classes.
To quantify how uncertain the classification of out-of-distribution data is, we need a reliable uncertainty estimator. A common approach is to treat the classifier's output, which ranges between 0 and 1, as the probability of class 1. From this, one can compute the (binary) entropy of the prediction at any point of the input space, which provides a measure of uncertainty. Below, we visualize this uncertainty using shades of blue.
y_pred = base_net.predict(x_ood, verbose=0).ravel()
uncertainties = -y_pred * np.log(y_pred + 1e-8) - (1 - y_pred) * np.log(1 - y_pred + 1e-8)
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper left"); plt.show()
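The entropy computed above is the binary entropy $H(p) = -p\log p - (1-p)\log(1-p)$. The small NumPy helper below is our own, and it clips $p$ instead of adding 1e-8. It shows why this measure is concentrated near the decision boundary. The entropy peaks at $\log 2$ when $p = 0.5$ and vanishes as $p$ approaches 0 or 1. As a result, any region where the network outputs a confident prediction receives near-zero uncertainty, including regions far from the data.

```python
import numpy as np

def binary_entropy(p, eps=1e-8):
    """Entropy (in nats) of a Bernoulli variable with success probability p.
    p is clipped to [eps, 1 - eps] so log() stays finite at exactly 0 or 1."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    return -p * np.log(p) - (1.0 - p) * np.log(1.0 - p)

probs = np.array([0.0, 0.1, 0.5, 0.9, 1.0])
h = binary_entropy(probs)
print(h)  # symmetric around 0.5, maximal there, near zero at both ends
```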
As observed, the estimated uncertainty is high mainly in the region between the two classes, while the network assigns almost no uncertainty elsewhere, even far from the training data. This is expected, because the uncertainty derived from the network's output is not designed to capture epistemic uncertainty, i.e., the uncertainty arising from a lack of knowledge about the true underlying classification function. A principled and widely used way to estimate epistemic uncertainty is to use an ensemble of classifiers. Because of differences in initialization and the stochastic nature of gradient descent, independently trained networks converge to different solutions. The diversity of their predictions serves as an indicator of epistemic uncertainty, and it helps us assess the confidence of the model in regions with little to no training data.
Deep Ensemble
The following code builds an ensemble of five networks, each trained independently on the same training data.
deep_ens = []
for _ in range(5):
    # clone_model copies the architecture but creates freshly initialized weights
    net = tf.keras.models.clone_model(base_net)
    net.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss="bce")
    net.fit(x_train, y_train, epochs=250, verbose=0)
    deep_ens.append(net)
y_preds = [
net.predict(x_ood, batch_size=1000, verbose=0)
for net in deep_ens
]
The standard deviation of the predictions across ensemble members provides an estimate of the epistemic uncertainty.
uncertainties = np.std(y_preds, axis=0)
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper left"); plt.show()
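Why use the standard deviation rather than the entropy of the averaged prediction? A toy example with made-up predictions from a hypothetical five-member ensemble shows the difference. At a point on the decision boundary, every member predicts about 0.5. At a point far from the data, the members disagree strongly. Both points have the same mean prediction, and therefore the same entropy, and only the spread across members separates them.

```python
import numpy as np

# Made-up predictions of 5 members (rows) at 3 query points (columns):
# column 0: near the training data, all members agree on class 0;
# column 1: on the decision boundary, all members agree on ~0.5;
# column 2: far from the data, members disagree strongly.
member_preds = np.array([
    [0.02, 0.50, 0.95],
    [0.03, 0.48, 0.10],
    [0.01, 0.52, 0.80],
    [0.02, 0.49, 0.05],
    [0.02, 0.51, 0.60],
])
mean_pred = member_preds.mean(axis=0)  # same mean (0.5) at columns 1 and 2
epistemic = member_preds.std(axis=0)   # spread across members, as np.std(y_preds, axis=0)
print(mean_pred)
print(epistemic)
```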
As we can see, the diversity of predictions remains quite limited, with a significant portion of the out-of-distribution data receiving low uncertainty estimates.
The core issue is that the ensemble lacks sufficient diversity. Despite variations in initialization and training, the networks tend to converge to similar solutions, limiting their ability to express different perspectives on uncertain regions.
This is precisely where MaxWEnt comes into play. Its objective is to maximize the diversity within the ensemble, ensuring a broader exploration of possible predictions.
Maximum Weight Entropy
To achieve this, we first introduce stochasticity into the network with the set_maxwent_model function. This function builds a new model from base_net in which every Dense layer is replaced by a DenseMaxWEnt layer.
from maxwent import set_maxwent_model
stoch_net = set_maxwent_model(base_net)
stoch_net.summary()
Model: "sequential_1"
Layer (type)                  Output Shape      Param #
dense_3_mwe (DenseMaxWEnt)    (None, 100)           604
re_lu_2_mwe (ReLU)            (None, 100)             0
dense_4_mwe (DenseMaxWEnt)    (None, 100)        30,200
re_lu_3_mwe (ReLU)            (None, 100)             0
dense_5_mwe (DenseMaxWEnt)    (None, 1)          10,202
Total params: 41,006 (160.18 KB)
Trainable params: 10,501 (41.02 KB)
Non-trainable params: 30,505 (119.16 KB)
Stochasticity
The DenseMaxWEnt layer behaves similarly to a regular Dense layer, with one key difference: during each forward pass, the weights are sampled randomly from a distribution. Specifically, each weight $W_{ij}$ of the weight matrix $W$ is drawn according to the following distribution:
$$ W_{ij} = \overline{W}_{ij} + \phi_{ij} \, z_{ij} $$
where the $z_{ij}$ are independent draws from a standard distribution with mean 0, either uniform or normal. Here, $\overline{W}$ is the weight matrix of the base network (base_net), which has been pretrained on the training data and remains frozen (i.e., non-trainable). The only trainable parameters are the $\phi_{ij}$, which set the spread of each weight distribution. Strictly speaking they are scale parameters: the variance of $W_{ij}$ is $\phi_{ij}^2$ times the variance of $z_{ij}$. Below, we follow the usual shorthand and call them variance parameters. Note that the matrix $\phi$ has the same shape as $\overline{W}$, namely (input_dim, output_dim), where input_dim is the number of neurons in the previous layer and output_dim is the number of neurons in the current layer.
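The scale parameters translate into weight variances in a simple way. Since $\mathrm{Var}(\overline{W}_{ij} + \phi_{ij} z_{ij}) = \phi_{ij}^2 \, \mathrm{Var}(z_{ij})$, a weight has variance $\phi_{ij}^2/3$ under uniform noise on $[-1, 1]$ and variance $\phi_{ij}^2$ under standard normal noise. The Monte Carlo check below uses arbitrary values for $\overline{W}_{ij}$ and $\phi_{ij}$, chosen by us for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
w_bar, phi, n = 0.3, 0.2, 200_000  # illustrative frozen weight, scale, sample count

# Uniform noise on [-1, 1] has variance 1/3; standard normal noise has variance 1.
w_uniform = w_bar + phi * rng.uniform(-1.0, 1.0, size=n)
w_normal = w_bar + phi * rng.standard_normal(size=n)

print(w_uniform.mean(), w_uniform.var(), phi**2 / 3)
print(w_normal.mean(), w_normal.var(), phi**2)
```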
As we can see from the model summary, we now have ~40,000 parameters instead of ~10,000. However, only ~10,000 of these are trainable parameters: those being the variance parameters $\phi$. Among the non-trainable parameters are $\overline{W}$ (the frozen base network weights) and a matrix $V$, whose utility will be discussed later.
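The per-layer counts in the summary can be reproduced with the breakdown below, which we inferred from the numbers alone. Under this breakdown, each DenseMaxWEnt layer holds three things:

- the frozen kernel and bias of the base layer,
- a trainable scale for each of them,
- a frozen square matrix $V$ of size (input_dim, input_dim).

The bias scales and the shape of $V$ are assumptions, not documented behavior. The function name is hypothetical.

```python
def maxwent_dense_params(n_in, n_out):
    """Hypothetical breakdown of one DenseMaxWEnt layer's parameters,
    inferred from the model summary (not from the library's source)."""
    frozen_mean = n_in * n_out + n_out      # W_bar and bias copied from base_net
    trainable_scale = n_in * n_out + n_out  # phi for the kernel and for the bias (assumed)
    frozen_v = n_in * n_in                  # matrix V, assumed (input_dim, input_dim)
    return trainable_scale, frozen_mean + frozen_v

layer_sizes = [2, 100, 100, 1]
per_layer = [maxwent_dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
totals = [t + f for t, f in per_layer]
trainable = sum(t for t, _ in per_layer)
non_trainable = sum(f for _, f in per_layer)
print(totals)                    # [604, 30200, 10202]
print(trainable, non_trainable)  # 10501 30505
```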
This stochastic network enables us to sample $z$ to generate different neural networks. The diversity of the resulting ensemble of networks is controlled by the variance parameters $\phi$. If $\phi = 0$, we simply sample the same network as base_net.
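To make the sampling mechanism concrete, here is a NumPy sketch of a single forward pass through a stochastic dense layer. It illustrates the formula above under our own assumptions and is not the DenseMaxWEnt implementation. We draw $z$ uniformly on $[-1, 1]$, and we also perturb the bias with its own scale phi_b. The parameter counts in the summary suggest such bias scales, but the text does not state it. The function name stochastic_dense is hypothetical.

```python
import numpy as np

def stochastic_dense(x, w_bar, b_bar, phi_w, phi_b, rng):
    """One forward pass of a MaxWEnt-style dense layer (illustrative sketch):
    draw z ~ U[-1, 1] elementwise, then use W = W_bar + phi * z."""
    z_w = rng.uniform(-1.0, 1.0, size=w_bar.shape)
    z_b = rng.uniform(-1.0, 1.0, size=b_bar.shape)
    w = w_bar + phi_w * z_w
    b = b_bar + phi_b * z_b  # bias perturbation is an assumption
    return x @ w + b

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))      # batch of 4 inputs, input_dim = 2
w_bar = rng.normal(size=(2, 3))  # frozen "pretrained" weights, output_dim = 3
b_bar = np.zeros(3)

# phi = 0: every sample reproduces the deterministic base layer.
out_det = stochastic_dense(x, w_bar, b_bar, np.zeros_like(w_bar), np.zeros(3), rng)

# phi > 0: two forward passes use two different sampled networks.
phi = 0.5 * np.ones_like(w_bar)
out_a = stochastic_dense(x, w_bar, b_bar, phi, 0.1 * np.ones(3), rng)
out_b = stochastic_dense(x, w_bar, b_bar, phi, 0.1 * np.ones(3), rng)
```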
Conceptually, we can think of this process as sampling within a “ball” centered on the weights of the pretrained base network. To encourage diversity, we want the variance parameters $\phi_{ij}$ to be as large as possible, but we must be careful not to generate networks that no longer fit the data. In the extreme case where many of the $\phi_{ij}$ are much larger than 1, we risk sampling degenerate networks that perform poorly on the training data.
Objective Function
The MaxWEnt trainer below strikes this balance by optimizing the $\phi$ parameters to minimize the following objective function:
$$ \mathbb{E}_z \left[ \sum_{(x, y) \in \mathcal{S}} \ell \left(y, h \left(x, \overline{W} + \phi \odot z \right) \right) \right] - \lambda \sum_{i, j} \log(\phi_{ij}^2) $$
Where:
- $\mathcal{S}$ is the training dataset.
- $\ell(\cdot, \cdot)$ is a loss function (the binary cross-entropy in our case).
- $h(x, W)$ is the prediction of the neural network with weights $W$ for the input $x \in \mathcal{X}$.
- $\odot$ is the element-wise product between two matrices.
- $z$ follows a standard multivariate distribution with independent components (uniform over $[-1, 1]$ in this case).
The first term of the objective is the expected loss over the training data. It prevents the $\phi_{ij}$ from growing too large in directions where perturbing the weights would degrade training performance. The second term is, up to an additive constant, proportional to the entropy of the weight distribution: the entropy of $\overline{W}_{ij} + \phi_{ij} z_{ij}$ equals that of $z_{ij}$ plus $\log|\phi_{ij}|$. This term rewards large $\phi_{ij}$, which can grow freely in directions where perturbations have little impact on the training loss. Together, the two terms ensure both diversity and a good fit to the training data.
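In practice, the expectation over $z$ cannot be computed exactly and is estimated by sampling. The NumPy sketch below estimates the objective for a hypothetical single-layer logistic model $h(x, W) = \sigma(x^\top W)$ with $z \sim \mathcal{U}[-1, 1]$. It uses the convention in which the objective is minimized: the summed expected loss minus $\lambda \sum \log \phi_{ij}^2$. This only illustrates the formula. The MaxWEnt trainer's internal estimator (number of samples, batching, loss reduction) may differ, and all names here are ours.

```python
import numpy as np

def sum_bce(y, p, eps=1e-7):
    """Binary cross-entropy summed over the dataset, as in the objective."""
    p = np.clip(p, eps, 1.0 - eps)
    return -np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def maxwent_objective(w_bar, phi, x, y, lam, n_samples=64, seed=0):
    """Monte Carlo estimate of E_z[sum loss] - lam * sum(log(phi**2))
    for a logistic model h(x, W) = sigmoid(x @ W)."""
    rng = np.random.default_rng(seed)
    losses = []
    for _ in range(n_samples):
        z = rng.uniform(-1.0, 1.0, size=w_bar.shape)  # independent U[-1, 1] noise
        w = w_bar + phi * z                           # sampled weights
        p = 1.0 / (1.0 + np.exp(-(x @ w)))            # sigmoid output
        losses.append(sum_bce(y, p))
    expected_loss = np.mean(losses)
    log_scale_term = np.sum(np.log(phi ** 2))
    return expected_loss - lam * log_scale_term

rng = np.random.default_rng(1)
x = rng.normal(size=(50, 2))
y = (x[:, 0] > 0).astype(float)  # labels depend only on the first feature
w_bar = np.array([3.0, 0.0])     # a "pretrained" weight vector that fits y
print(maxwent_objective(w_bar, np.array([0.5, 0.5]), x, y, lam=1.0))
```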
Training
from maxwent import MaxWEnt
mwe = MaxWEnt(stoch_net, lambda_=1.)
The parameter lambda_ controls the trade-off between the two objectives. Increasing lambda_ encourages greater weight diversity, but it should not be set too high, as this may excessively deteriorate the model’s accuracy on the training data. The number of training epochs also plays a role in this process. Typically, a larger lambda_ can achieve higher weight diversity with fewer epochs, but beyond a certain point, the training loss may increase drastically. A practical heuristic is to increase both lambda_ and the number of epochs as much as possible without causing the training loss to become unstable. Smaller values of lambda_ result in a more gradual increase in weight entropy, which allows for greater flexibility in the final stages of training. However, this can also lead to a longer training time before a significant increase in weight entropy occurs.
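A one-dimensional toy problem makes the role of lambda_ and of the loss sensitivity concrete. As a simplifying assumption of ours (not one MaxWEnt makes), suppose the training loss along a single weight is quadratic around $\overline{w}$ with curvature $a$. With $w = \overline{w} + \phi z$ and $z \sim \mathcal{U}[-1, 1]$, for which $\mathbb{E}[z^2] = 1/3$, the objective becomes $J(\phi) = a\phi^2/3 - \lambda \log(\phi^2)$ up to a constant. Setting $J'(\phi) = 2a\phi/3 - 2\lambda/\phi$ to zero gives $(\phi^*)^2 = 3\lambda/a$. The optimal scale therefore grows with $\lambda$ and shrinks where the loss is sharp. The sketch checks this closed form against plain gradient descent on $\log\phi$.

```python
import numpy as np

def optimal_phi(a, lam):
    """Closed-form minimizer of J(phi) = a * phi**2 / 3 - lam * log(phi**2)."""
    return np.sqrt(3.0 * lam / a)

def fit_phi(a, lam, lr=0.05, steps=5000):
    """Gradient descent on s = log(phi), which keeps phi positive."""
    s = 0.0
    for _ in range(steps):
        phi = np.exp(s)
        grad_phi = 2.0 * a * phi / 3.0 - 2.0 * lam / phi  # dJ/dphi
        s -= lr * grad_phi * phi                          # dJ/ds = dJ/dphi * phi
    return np.exp(s)

for a in (0.1, 10.0):        # flat vs. sharp direction of the loss
    for lam in (0.01, 1.0):  # weak vs. strong entropy pressure
        print(a, lam, fit_phi(a, lam), optimal_phi(a, lam))
```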
mwe.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss="bce")
mwe.fit(x_train, y_train, epochs=2500, verbose=0);
Inference
We now make predictions using the MaxWEnt model. It’s important to note that each time predict is called, a new set of weights is sampled, resulting in different predictions with every call. To obtain consistent predictions across multiple calls, you can use the seed argument. Using a fixed seed is a good practice, as it ensures that the same network is used for each batch during prediction. Without it, a new network is sampled for each batch, which can lead to discontinuous predictions along the input space $\mathcal{X}$.
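The effect of the seed can be reproduced with a toy stochastic model. In the sketch below, predict is a hypothetical function, not the MaxWEnt method, and reseeding per batch is just one simple way to obtain the behavior described above. When a seed is given, the function rebuilds the random generator from that seed for every batch, so all batches and all calls use the same sampled network. Without a seed, each batch gets a freshly sampled network, which produces the discontinuities mentioned above.

```python
import numpy as np

# Hypothetical stochastic model: a linear map whose weights are perturbed
# by uniform noise each time a "network" is sampled.
W_BAR = np.array([[1.0], [-2.0]])
PHI = 0.5

def sample_weights(rng):
    return W_BAR + PHI * rng.uniform(-1.0, 1.0, size=W_BAR.shape)

def predict(x, batch_size, seed=None):
    """Batched prediction. With a seed, the same sampled network is reused
    for every batch; without one, each batch gets a newly sampled network."""
    outputs = []
    fresh_rng = np.random.default_rng()
    for start in range(0, len(x), batch_size):
        rng = np.random.default_rng(seed) if seed is not None else fresh_rng
        w = sample_weights(rng)
        outputs.append(x[start:start + batch_size] @ w)
    return np.vstack(outputs)

x = np.linspace(-1.0, 1.0, 20).reshape(10, 2)
a = predict(x, batch_size=3, seed=123)   # one network, split over 4 batches
b = predict(x, batch_size=10, seed=123)  # same network, single batch
```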
n_sample = 50
y_preds = [
mwe.predict(x_ood, batch_size=1000, seed=123+i)
for i in range(n_sample)
]
y_pred = np.mean(y_preds, axis=0)
uncertainties = -y_pred * np.log(y_pred + 1e-8) - (1 - y_pred) * np.log(1 - y_pred + 1e-8)
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper right"); plt.show()
The MaxWEnt object also provides a predict_mean method. As its name and its n_sample argument suggest, it averages the predictions of several sampled networks in a single call:
y_pred = mwe.predict_mean(x_ood, batch_size=1000, clip=None, n_sample=50)
uncertainties = -y_pred * np.log(y_pred + 1e-8) - (1 - y_pred) * np.log(1 - y_pred + 1e-8)
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper right"); plt.show()
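Instead of the entropy of the mean prediction, we can also use the standard deviation of the sampled predictions as the uncertainty score, as we did for the deep ensemble. The predict_std method provides it directly:
uncertainties = mwe.predict_std(x_ood, batch_size=1000, clip=None, n_sample=50).ravel()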
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper right"); plt.show()
We can further improve the result with the SVD parameterization. To use it, we only need to fit the SVD matrices with fit_svd before training the model:
stoch_net = set_maxwent_model(base_net)
mwe = MaxWEnt(stoch_net, lambda_=1.)
mwe.fit_svd(x_train)
mwe.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss="bce")
mwe.fit(x_train, y_train, epochs=2500, verbose=0);
uncertainties = mwe.predict_std(x_ood, batch_size=1000, clip=None, n_sample=50).ravel()
ax = plot_classification_2d(x_train, y_train, x_ood)
unc = ax.scatter(x_ood[:, 0], x_ood[:, 1], c=uncertainties, cmap="Blues")
plt.colorbar(unc, ax=ax, label='Uncertainty')
ax.legend(loc="upper right"); plt.show()
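The tutorial does not describe what fit_svd computes. One plausible reading, which we have not checked against the library's source, builds on the extra matrix $V$ of shape (input_dim, input_dim) seen in the summary. Under this reading, $V$ holds the right singular vectors of each layer's inputs on x_train, and the noise is applied in that basis. The sketch below only illustrates the geometric intuition behind such a choice. If the weights are perturbed along the $k$-th right singular direction, the outputs on the training inputs change by an amount proportional to the $k$-th singular value. Directions in which the training inputs barely vary can therefore receive large $\phi$ at almost no cost in training loss. All data and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy layer inputs that vary mostly along one direction (nearly rank one).
t = rng.normal(size=(200, 1))
x_layer = np.hstack([t, 0.5 * t]) + 0.01 * rng.normal(size=(200, 2))

# Right singular vectors of the input matrix form an orthonormal
# (input_dim, input_dim) basis; s holds the singular values.
_, s, vt = np.linalg.svd(x_layer, full_matrices=False)
V = vt.T

w_bar = np.array([[1.0], [1.0]])
changes = []
for k in range(V.shape[1]):
    dw = 0.5 * V[:, [k]]                   # perturb along the k-th direction
    delta = x_layer @ (w_bar + dw) - x_layer @ w_bar
    changes.append(np.linalg.norm(delta))  # equals 0.5 * s[k]
print(s)
print(changes)
```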