The Expectation-Maximization (EM) algorithm is widely used in clustering and segmentation problems. Despite its usefulness, not many people know how it works. In this post, I will present the internals of the EM algorithm. Then, I’ll walk you through an example where I develop the EM algorithm from scratch. Finally, I’ll mention the most common issues you may face when applying the EM algorithm to your own problem.
A user segmentation problem
Imagine an important football match is about to take place. To celebrate the event, both teams set up tents not far from each other where fans can grab some drinks and socialize before heading to the stadium. I’ll use the city of Madrid in my example just because I currently live there.
Assume fans gather in every direction around their team’s tent, but the tents’ locations are unknown. Putting on our data scientist hat, we can say that fans of the same team make up a cluster, and the center of the cluster is the estimated location of the tent.
To find the true location of the tents, we can start with a wild guess. The EM algorithm then iteratively improves the estimate of their location. We can divide this process into two steps:
- Expectation: Estimate which team each fan supports, given how far that fan is from each of the tents.
- Maximization: Estimate the location of each team’s tent from its fans’ locations, together with the area each group of fans occupies in the city and the proportion of fans belonging to each team.
When trying to identify the team each fan supports, we have to deal with probabilities, in the sense that we say something like: there’s an 80% chance this fan supports the blue team and a 20% chance he or she supports the red team. At the end, every fan is assigned to the team with the higher probability.
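A minimal NumPy sketch of that final hard assignment, assuming the per-fan probabilities are already available. The probability values and the team labels below are made up for illustration.

```python
import numpy as np

# Hypothetical probabilities for 3 fans (rows) over 2 teams (columns: blue, red).
# Each row sums to 1, as a soft assignment should.
probs = np.array([
    [0.80, 0.20],
    [0.30, 0.70],
    [0.55, 0.45],
])
teams = np.array(["blue", "red"])

# Hard assignment: each fan goes to the team with the highest probability.
assigned = teams[np.argmax(probs, axis=1)]
print(assigned)  # ['blue' 'red' 'blue']
```

Note that the third fan is assigned to the blue team even though the evidence is weak; the soft probabilities keep that uncertainty, the hard label does not.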
How are the locations of the teams’ tents estimated? Given our assumption that fans gather in every direction around their team’s tent, it makes sense to model the locations of each team’s fans with a Gaussian distribution centered at the tent.
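As a sketch of what modeling a team’s fans with a Gaussian means, the density below (plain NumPy; the function name `gaussian_pdf` and the tent coordinates are illustrative assumptions) is highest at the tent and falls off with distance from it:

```python
import numpy as np

def gaussian_pdf(x, mean, cov):
    """Multivariate Gaussian density evaluated at each row of x."""
    x = np.atleast_2d(x)
    d = mean.shape[0]
    diff = x - mean                                   # offset of each fan from the tent
    # Squared Mahalanobis distance diff_i^T cov^{-1} diff_i for every row i.
    mahal = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
    norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return np.exp(-0.5 * mahal) / norm

# Hypothetical tent location (map coordinates) and spread of its fans.
tent = np.array([0.0, 0.0])
spread = np.eye(2)

fans = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
densities = gaussian_pdf(fans, tent, spread)
print(densities)  # largest at the tent, decreasing with distance
```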
EM from the ground up
Formally, the EM algorithm is an iterative method for finding maximum-likelihood estimates of the parameters of a statistical model when the data is incomplete, that is, when the model depends on variables we cannot observe. Going back to our example, we label the data as follows:
- X as the observed variables, which are the locations (map coordinates) of all the fans
- Z as the latent variables, which represent the team each fan supports. Since Z is not observed, we work with the probability of each fan belonging to each team
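One way to see the roles of X and Z is to generate data the way a GMM assumes it was produced: first draw each fan’s team (Z), then draw the fan’s location around that team’s tent (X). The tent locations, spreads and proportions below are hypothetical, and NumPy is used instead of TensorFlow to keep the sketch small.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical model: two tents (cluster centers), their spreads and team proportions.
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.array([np.eye(2), 0.5 * np.eye(2)])
weights = np.array([0.6, 0.4])

n_fans = 500
# Latent variable Z: the team of each fan (not observed in practice).
z = rng.choice(len(weights), size=n_fans, p=weights)
# Observed variable X: each fan's location, drawn around their team's tent.
x = np.array([rng.multivariate_normal(means[k], covs[k]) for k in z])

print(x.shape)  # (500, 2)
# EM only gets to see x; z is what it has to infer.
```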
We don’t know the true value of Z, so instead we are going to estimate it. This step is important because estimating the parameters requires knowing the pairs (X, Z). Briefly, the learning process goes as follows:
- Set the initial parameters for your model
- Estimate values of Z given the initial parameters of your model (E-step)
- Re-estimate the parameters of your model with complete data (X, Z) (M-step)
- Re-estimate the value of Z with the new estimation of parameters (E-step)
You can repeat steps 3 and 4 (the M-step and the E-step) until the parameters of your model converge.
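The four steps can be sketched on a simpler 1-D problem. This is a minimal NumPy illustration, not the post’s implementation; the helper names (`normal_pdf`, `em_1d`), the synthetic data and the stopping rule (stop when the log-likelihood improves by less than `tol`) are assumptions made for the example.

```python
import numpy as np

def normal_pdf(x, mu, var):
    """1-D Gaussian density, broadcast over points and components."""
    return np.exp(-0.5 * (x - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)

def em_1d(x, mu, var, pi, tol=1e-6, max_iter=500):
    """EM for a 1-D Gaussian mixture; stops once the log-likelihood stalls."""
    prev_ll = -np.inf
    for it in range(max_iter):
        # E-step: probability r[n, k] that point n belongs to component k.
        dens = pi * normal_pdf(x[:, None], mu, var)       # shape (N, K)
        total = dens.sum(axis=1, keepdims=True)
        r = dens / total
        ll = np.log(total).sum()                           # log-likelihood of current params
        if ll - prev_ll < tol:                             # converged: stop iterating
            break
        prev_ll = ll
        # M-step: re-estimate the parameters from the completed data (x, r).
        nk = r.sum(axis=0)
        pi = nk / len(x)
        mu = (r * x[:, None]).sum(axis=0) / nk
        var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / nk
    return mu, var, pi, ll, it

rng = np.random.default_rng(1)
# Hypothetical 1-D data: 300 points around -2 and 200 points around 3.
x = np.concatenate([rng.normal(-2.0, 1.0, 300), rng.normal(3.0, 0.5, 200)])

# Step 1: rough initial parameters; em_1d then alternates E- and M-steps.
mu, var, pi, ll, iters = em_1d(
    x, mu=np.array([-1.0, 1.0]), var=np.array([1.0, 1.0]), pi=np.array([0.5, 0.5])
)
print(mu, pi, iters)
```

EM never decreases the log-likelihood from one iteration to the next, which is why a small improvement is a reasonable convergence signal.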
Gaussian mixture models
You can assume your data follows a Gaussian mixture model (GMM) when it is scattered into small but prominent clusters, like the ones you saw in the first figure. Each of these clusters is modeled as a Gaussian distribution defined by two parameters: its mean (the center of the cluster) and its covariance matrix (the dispersion of the cluster).
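EM looks for the parameters (one mean and covariance per cluster, plus the proportion of fans in each cluster) that maximize the log-likelihood of the observed locations under the mixture. A small NumPy sketch of that objective, with hypothetical parameter values and illustrative helper names:

```python
import numpy as np

def gaussian_pdf(x, mean, cov):
    """Multivariate Gaussian density evaluated at each row of x."""
    diff = x - mean
    mahal = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
    return np.exp(-0.5 * mahal) / np.sqrt((2 * np.pi) ** len(mean) * np.linalg.det(cov))

def gmm_log_likelihood(x, weights, means, covs):
    """Sum over fans of log(sum_k weight_k * N(x_n | mean_k, cov_k))."""
    mix = sum(w * gaussian_pdf(x, m, c) for w, m, c in zip(weights, means, covs))
    return np.log(mix).sum()

# Hypothetical GMM: one (mean, covariance) pair per cluster plus mixing weights.
weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.array([np.eye(2), np.eye(2)])

x = np.array([[0.1, 0.2], [5.2, 4.9]])
good = gmm_log_likelihood(x, weights, means, covs)
bad = gmm_log_likelihood(x, weights, means + 2.0, covs)  # both centers shifted away
print(good > bad)  # True
```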
Our job is to estimate those parameters, plus the proportion of fans in each cluster (called mixing weights from now on). In the M-step, we update the parameters with the following equations.
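In standard GMM notation, with N fans at locations x_1 to x_N, K clusters and r_nk the probability that fan n belongs to cluster k (symbols assumed here), the mixing weight of cluster k is:

```latex
N_k = \sum_{n=1}^{N} r_{nk}, \qquad \pi_k = \frac{N_k}{N}
```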
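These are the mixing weights. In short, each one is the number of fans of a team divided by the total number of fans, where, because assignments are probabilistic, a team’s number of fans is the sum of every fan’s probability of supporting it. Next, we show how the mean is calculated: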
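Using standard notation (assumed here: r_nk is the probability that fan n, at location x_n, belongs to cluster k, and N_k is the sum of r_nk over all N fans), the mean update is:

```latex
\mu_k = \frac{1}{N_k} \sum_{n=1}^{N} r_{nk}\, x_n, \qquad N_k = \sum_{n=1}^{N} r_{nk}
```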
The center of each cluster is calculated as the weighted average of the locations of all the fans. The weights r are the probabilities that each fan belongs to the cluster whose center we are calculating. Fans that belong to the other cluster have a very small weight, so their contribution to that location is insignificant.
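The weighted average is easy to check numerically. In this sketch (hypothetical locations and weights `r_k`), the fan far from the cluster gets a tiny weight and barely moves the center, whereas an unweighted mean would be dragged toward it:

```python
import numpy as np

# Hypothetical fan locations and their probability r of belonging to cluster k.
x = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
r_k = np.array([0.9, 0.9, 0.01])   # the far-away fan barely belongs to cluster k

n_k = r_k.sum()                                  # effective number of fans in cluster k
mu_k = (r_k[:, None] * x).sum(axis=0) / n_k      # weighted average of the locations
print(np.round(mu_k, 2))                         # [1.05 0.06]
print(x.mean(axis=0))                            # unweighted mean, dragged toward the far fan
```

The same value can be obtained with `np.average(x, axis=0, weights=r_k)`.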
Now, to update the covariance matrix of each cluster, we need the cluster centers we have just obtained:
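In standard notation (assumed here: r_nk is the probability that fan n, at location x_n, belongs to cluster k, N_k is the sum of r_nk over all fans, and mu_k is the updated center), the covariance update is:

```latex
\Sigma_k = \frac{1}{N_k} \sum_{n=1}^{N} r_{nk}\, (x_n - \mu_k)(x_n - \mu_k)^{\top}
```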
In the same vein as the previous equation, this is a weighted average, here of the outer products of each supporter’s deviation from the cluster’s center. Supporters far away from the center make the covariance increase, but this is balanced out by their small probabilities r.
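A NumPy sketch of the covariance update for one cluster, on hypothetical data (the helper name `weighted_cov` is an assumption). It compares the result when a far-away fan is down-weighted by r with the result when that fan counts fully:

```python
import numpy as np

def weighted_cov(x, r_k):
    """Covariance update for one cluster: r-weighted average of outer products."""
    n_k = r_k.sum()                                   # effective number of fans in the cluster
    mu_k = (r_k[:, None] * x).sum(axis=0) / n_k       # updated center (mean update)
    diff = x - mu_k
    # Sum_n r_nk * (x_n - mu_k)(x_n - mu_k)^T, divided by N_k.
    outer = np.einsum("ni,nj->nij", diff, diff)
    return (r_k[:, None, None] * outer).sum(axis=0) / n_k

# Hypothetical fans: three near the center and one far away.
x = np.array([[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [10.0, 10.0]])

sigma_k = weighted_cov(x, np.array([0.9, 0.9, 0.9, 0.01]))     # far fan barely counts
sigma_naive = weighted_cov(x, np.array([0.9, 0.9, 0.9, 0.9]))  # far fan fully counted
print(np.trace(sigma_k) < np.trace(sigma_naive))  # True
```

NumPy’s `np.cov(x, rowvar=False, aweights=r_k, bias=True)` computes the same weighted covariance.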
How is the E-step performed? This is its equation for a GMM:
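In standard notation (assumed here: pi_k, mu_k and Sigma_k are the current mixing weight, mean and covariance of cluster k, and N(x | mu, Sigma) is the Gaussian density), the probability that fan n belongs to cluster k is:

```latex
r_{nk} = \frac{\pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j)}
```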
This can be interpreted as the posterior probability of Z once we take into account the observed variables X (and the current estimate of the parameters).
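A minimal NumPy version of this E-step, assuming the current mixing weights, means and covariances are given (the helper names and values are illustrative):

```python
import numpy as np

def gaussian_pdf(x, mean, cov):
    """Multivariate Gaussian density evaluated at each row of x."""
    diff = x - mean
    mahal = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
    return np.exp(-0.5 * mahal) / np.sqrt((2 * np.pi) ** len(mean) * np.linalg.det(cov))

def e_step(x, weights, means, covs):
    """Posterior probability r[n, k] that fan n belongs to cluster k."""
    # Numerator: weight_k * N(x_n | mean_k, cov_k) for every fan and cluster.
    num = np.column_stack([w * gaussian_pdf(x, m, c)
                           for w, m, c in zip(weights, means, covs)])
    # Divide by the sum over clusters so every row is a distribution over Z.
    return num / num.sum(axis=1, keepdims=True)

# Hypothetical current parameters and three fans.
x = np.array([[0.2, -0.1], [4.8, 5.3], [2.5, 2.5]])
r = e_step(x,
           weights=np.array([0.5, 0.5]),
           means=np.array([[0.0, 0.0], [5.0, 5.0]]),
           covs=np.array([np.eye(2), np.eye(2)]))
print(np.round(r, 3))  # the fan halfway between the tents gets [0.5, 0.5]
```

For fans very far from every tent, both numerators can underflow to zero and the division yields NaN; practical implementations therefore compute this step in log space.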
Let’s see this algorithm in action with an example written in TensorFlow and TensorFlow Probability. We start with a random initial guess for the clusters:
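import tensorflow as tf

# Initial guess for the first cluster: center and covariance matrix
estimate_mu1 = tf.Variable([15., 11.])
estimate_sigma1 = tf.Variable([[2., 0.], [0., 2.]], dtype=tf.float32)
# Initial guess for the second cluster
estimate_mu2 = tf.Variable([0., -1.])
estimate_sigma2 = tf.Variable([[1.2, 0.], [0., 1.2]], dtype=tf.float32)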
We run this algorithm for 10 iterations:
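As a rough NumPy counterpart (not the post’s TensorFlow code), the sketch below runs 10 EM iterations starting from the same initial means and covariances, on synthetic fans drawn around two hypothetical tents. The equal initial mixing weights and the data are assumptions; the E-step works in log space so that fans far from a tent do not underflow to zero probability.

```python
import numpy as np

def log_gaussian(x, mean, cov):
    """Log-density of a multivariate Gaussian at each row of x."""
    diff = x - mean
    mahal = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (mahal + logdet + len(mean) * np.log(2 * np.pi))

def em_gmm(x, means, covs, weights, n_iter=10):
    """Fixed number of EM iterations for a GMM with full covariances."""
    means, covs, weights = means.copy(), covs.copy(), weights.copy()
    for _ in range(n_iter):
        # E-step in log space, normalized row by row (log-sum-exp trick).
        log_num = np.column_stack([np.log(w) + log_gaussian(x, m, c)
                                   for w, m, c in zip(weights, means, covs)])
        log_num -= log_num.max(axis=1, keepdims=True)
        r = np.exp(log_num)
        r /= r.sum(axis=1, keepdims=True)
        # M-step: mixing weights, means and covariances.
        n_k = r.sum(axis=0)
        weights = n_k / len(x)
        means = (r.T @ x) / n_k[:, None]
        for k in range(len(weights)):
            diff = x - means[k]
            covs[k] = (r[:, k, None] * diff).T @ diff / n_k[k]
    return means, covs, weights

rng = np.random.default_rng(42)
# Hypothetical fans: two groups around tents at (10, 8) and (2, 1).
x = np.vstack([rng.normal([10.0, 8.0], 1.0, size=(200, 2)),
               rng.normal([2.0, 1.0], 1.0, size=(200, 2))])

# Same initial means and covariances as the TensorFlow snippet; equal weights assumed.
means0 = np.array([[15.0, 11.0], [0.0, -1.0]])
covs0 = np.array([2.0 * np.eye(2), 1.2 * np.eye(2)])
weights0 = np.array([0.5, 0.5])

means, covs, weights = em_gmm(x, means0, covs0, weights0, n_iter=10)
print(np.round(means, 2))
```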
It eventually models both groups correctly. Depending on the number of fans and the distance between the tents, the algorithm may take more or fewer iterations to converge.
Issues when modeling GMM
The number of clusters to choose depends on the problem at hand, but the initial parameters of the clusters can be set with the K-means algorithm. Many ready-to-use libraries run some variant of this method before running the EM algorithm.
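A sketch of that idea in NumPy: plain K-means (Lloyd’s algorithm) finds hard clusters, which are then turned into initial means, covariances and mixing weights. The function name `kmeans_init` and the fallback for clusters with fewer than two points are assumptions of this sketch; scikit-learn’s `GaussianMixture`, for instance, uses K-means initialization by default (`init_params='kmeans'`).

```python
import numpy as np

def kmeans_init(x, k, n_iter=20, seed=0):
    """Seed GMM parameters with plain K-means (Lloyd's algorithm)."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), size=k, replace=False)]   # k random fans as centers
    for _ in range(n_iter):
        # Assign every fan to its nearest center (squared Euclidean distance).
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its fans (keep it if it has none).
        centers = np.array([x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                            for j in range(k)])
    # Turn the hard clusters into initial GMM parameters.
    weights = np.array([np.mean(labels == j) for j in range(k)])
    covs = np.array([np.cov(x[labels == j], rowvar=False) if np.sum(labels == j) > 1
                     else np.cov(x, rowvar=False)          # fallback for tiny clusters
                     for j in range(k)])
    return centers, covs, weights

rng = np.random.default_rng(3)
# Hypothetical fans around two tents at (0, 0) and (6, 6).
x = np.vstack([rng.normal([0.0, 0.0], 1.0, size=(150, 2)),
               rng.normal([6.0, 6.0], 1.0, size=(150, 2))])
centers, covs, weights = kmeans_init(x, k=2)
print(np.round(centers, 2), weights)
```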
Another issue you may run into is that different initial settings of the GMM may lead to different results. This happens more frequently when the problem requires fitting more than two clusters, when the clusters are very close to each other, when the data is not normally distributed, or when you want to fit a cluster with very low variance (a cluster composed of very few points).
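For the very-low-variance case, a common safeguard is to add a small constant to the diagonal of each covariance matrix so that it stays invertible; scikit-learn’s `GaussianMixture` exposes this as `reg_covar` (default `1e-6`). A minimal NumPy sketch with a hypothetical one-fan cluster, where the helper name `regularized_cov` is illustrative:

```python
import numpy as np

def regularized_cov(points, weights, reg=1e-6):
    """Weighted covariance of one cluster plus a small ridge on the diagonal."""
    mean = np.average(points, axis=0, weights=weights)
    diff = points - mean
    cov = (weights[:, None] * diff).T @ diff / weights.sum()
    # The ridge keeps the matrix invertible even if the cluster has collapsed
    # onto a single point, which would otherwise make the density blow up.
    return cov + reg * np.eye(points.shape[1])

# Hypothetical degenerate cluster: a single fan at (3, 4).
lonely = np.array([[3.0, 4.0]])
w = np.array([1.0])

raw = regularized_cov(lonely, w, reg=0.0)     # all zeros: singular
safe = regularized_cov(lonely, w, reg=1e-6)   # 1e-6 * identity: invertible
print(np.linalg.det(raw), np.linalg.det(safe))
```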
All the popular implementations should handle these issues in one way or another. Needless to say, you should use one of them instead of my implementation, which you can find on my GitHub repo. There, you will also find a second case where the initial parameters apparently make it more difficult to model both clusters.