Predicting user churn for Sparkify

Afagustin7
Aug 25, 2021

Project overview

As a capstone project for the Udacity Data Scientist Nanodegree, I decided to try to predict user churn in a dataset provided by the fictional music streaming platform Sparkify.

For this project we are given log data of this platform as a JSON file and we are asked to generate insights in order to predict which users will churn. Any platform is interested in detecting which users are likely to cancel the service so as to take action (offer discounts, promotions etc.) to retain them. But companies also have to be careful and avoid giving discounts to users that are not actually thinking about leaving the service.

Problem statement

Our objective is to learn from this dataset what behaviours can allow us to predict whether users will churn – meaning the user cancels the service. Therefore, we have a classification problem and a binary target variable. In order to make good predictions, I will extract the most relevant features from the data, and train a machine learning classifier.

Metrics

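Both kinds of error are costly: missing a user who is about to churn means losing them, while flagging a loyal user means giving away an unnecessary discount. We are therefore interested not only in recall (identifying as many of the users who will churn as possible) but also in precision (ensuring the users we identify are actually likely to churn). That is why the main metric I use to evaluate the models is the F1-score, the harmonic mean of the two.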
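
To make the metric concrete, here is a minimal plain-Python sketch of precision, recall and F1 for the churn class. It is not the Spark evaluator used in the project, and the toy labels are invented for illustration.

```python
def precision_recall_f1(y_true, y_pred):
    """Return (precision, recall, F1) for the positive class (1 = churned)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # churners we flagged
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # loyal users we flagged
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # churners we missed
    # Precision: share of flagged users who really churn.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # Recall: share of real churners that we flagged.
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy labels, invented for illustration only.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]
precision, recall, f1 = precision_recall_f1(y_true, y_pred)
```

Note that this is the F1 of the churn class alone. If a score is produced by an evaluator that averages F1 over both classes, it will not match this single-class definition exactly.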

Data exploration and visualization

The data used for this article is 128 MB in size and consists of 286,500 rows, each recording an action taken by one of 277 identified users, and 18 features. The features convey diverse information: some describe the action the user is performing (for example, whether they are listening to a song), while others describe personal characteristics such as gender or location.

Additionally, it is important to note that the data used here is a small subset of the available data: the full dataset is 12 GB. However, I use the Spark framework and keep scalability in mind, so that the same code can be reused on the full dataset.

Doing a bit of analysis we find a series of interesting facts about the data:

  • Several user identifiers are blank, probably because the user has not registered yet or has cancelled the service. After removing them, 225 users remain, and after the feature engineering I end up with 215.
  • The location refers to the city and the state. Since there are too many cities, I only keep the state.
  • The actions are recorded during a period of approximately 2 months.
  • 144 users have a paid plan and 71 use the service for free.
  • The variable “Page” is very important as it gives information about the user’s actions. The following figure shows how frequently each action is taken:

Users change songs very frequently, but they also thumb up a song or go to the Home page. The least frequent category is “CancellationConfirmation”, which I use to determine whether a user has churned:

  • Of the 215 users, 50 have churned and the other 165 continue to use the service.
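
The churn label can be sketched as follows: a user is marked as churned if any of their log rows has the cancellation page. This is a plain-Python illustration rather than the Spark code from the project; the sample rows and the `label_churn` helper are hypothetical, and the page label simply follows the spelling used above.

```python
# Page label as spelled in this article (an assumption about the raw data).
CHURN_PAGE = "CancellationConfirmation"

# Hypothetical log rows: (userId, page).
events = [
    ("10", "NextSong"),
    ("10", "Thumbs Up"),
    ("20", "NextSong"),
    ("20", CHURN_PAGE),
    ("30", "Home"),
]


def label_churn(events, churn_page=CHURN_PAGE):
    """Map each user to 1 if they ever reached the churn page, else 0."""
    labels = {}
    for user_id, page in events:
        hit = int(page == churn_page)
        # Once a user is flagged, later rows cannot unflag them.
        labels[user_id] = max(labels.get(user_id, 0), hit)
    return labels


churn = label_churn(events)
```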

Data preprocessing and Implementation

Now that we know the dataset, it is time to derive new variables from the existing ones.

The workflow to process the data is the following:

  • load JSON dataset
  • basic cleaning: exclude entries with an empty user ID, remove duplicates, clean the location variable (keep only the state), format the time variable and assign the correct types to all the variables.
  • feature engineering, which consists of two steps: 1) extraction of new features and 2) preparation of the data for modelling.
  1. Most of the computations consisted of aggregations at the user level, since the rows describe actions and not users. For example, in one case I calculated the number of songs played in each session and then averaged this metric at the user level.
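
The two-level aggregation in that example (count per session, then average per user) can be sketched in plain Python. The rows, the `avg_songs_per_session` helper and the "NextSong" page label are assumptions made for illustration; the project itself does this with Spark aggregations.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical log rows: (userId, sessionId, page).
events = [
    ("u1", 1, "NextSong"), ("u1", 1, "NextSong"), ("u1", 1, "Home"),
    ("u1", 2, "NextSong"),
    ("u2", 7, "NextSong"), ("u2", 7, "NextSong"),
]


def avg_songs_per_session(events, song_page="NextSong"):
    """Count songs per (user, session), then average those counts per user."""
    songs = defaultdict(int)
    for user_id, session_id, page in events:
        # Adding 0 still registers the session, so song-free sessions count as 0.
        songs[(user_id, session_id)] += int(page == song_page)

    per_user = defaultdict(list)
    for (user_id, _session_id), n_songs in songs.items():
        per_user[user_id].append(n_songs)

    return {user_id: mean(counts) for user_id, counts in per_user.items()}


features = avg_songs_per_session(events)
```

The result has one row per user instead of one row per action, which is the shape the classifiers need.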

The resulting final variables are:

  • Gender
  • Platform
  • Level
  • Registered days
  • Number of unique artists
  • The average number of songs listened to per session
  • The average number of “next song” actions per session
  • The average number of “add to playlist” actions per session
  • Number of sessions

Many more variables can be calculated, but I preferred to focus on the ones I consider most important. For one, logistic regression tends to overfit when the number of features is large relative to the number of observations, and the dataset used here is not that big.

  2. To prepare the features for modelling I needed to apply a series of transformations. For the categorical variables, I use StringIndexer to transform the strings into index numbers and then OneHotEncoder to create the dummy variables (one for each category). I did not transform the continuous variables, although I could have normalised them: tree-based models are insensitive to feature scale, and Spark’s LogisticRegression standardises the features internally by default. Lastly, I assembled all the features into a single feature vector (to use Spark’s terminology).
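
The index-then-encode idea behind the categorical step can be illustrated in plain Python. The helpers `fit_string_index` and `one_hot` are hypothetical stand-ins, not Spark's StringIndexer and OneHotEncoder, and the ordering rule (most frequent label first) is an assumption of this sketch.

```python
from collections import Counter


def fit_string_index(values):
    """Assign index 0 to the most frequent label, 1 to the next, and so on."""
    counts = Counter(values)
    # Ties are broken alphabetically so the mapping is deterministic.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: index for index, label in enumerate(ordered)}


def one_hot(index, size):
    """Return a list of zeros with a single 1 at position `index`."""
    vector = [0] * size
    vector[index] = 1
    return vector


# Toy values for the "Level" variable, invented for illustration.
levels = ["paid", "free", "paid", "paid", "free"]
mapping = fit_string_index(levels)
encoded = [one_hot(mapping[level], len(mapping)) for level in levels]
```

Spark's own encoder has options that this sketch ignores, such as dropping one category to avoid redundant columns.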

  • Finally, I save these features to a CSV file so that they can be retrieved when needed without repeating the steps above.

I partition the dataset into a training set (80% of the sample) and a test set (20%) and fit three binary classification models: logistic regression, random forest and gradient boosting, using Spark’s default hyperparameters for all three. There were no major complications in this part of the implementation; the only issue was to be careful about the encoding of the categorical variables.
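
As a rough illustration of the 80/20 partition, the following plain-Python sketch shuffles user IDs with a fixed seed and cuts the list. The `split_users` helper, the seed and the generated IDs are hypothetical; a random split in Spark may produce only approximately the requested proportions, whereas this sketch cuts exactly.

```python
import random


def split_users(user_ids, train_fraction=0.8, seed=42):
    """Shuffle the IDs reproducibly and split them into train and test lists."""
    shuffled = sorted(user_ids)            # fixed starting order
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is repeatable
    cut = round(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


# 215 placeholder IDs, matching the number of users after feature engineering.
user_ids = [f"user_{i}" for i in range(215)]
train_ids, test_ids = split_users(user_ids)
```

Fixing the seed matters here: with only 215 users, different random splits can change the test metrics noticeably.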

Refinement

To try to improve the initial results I tuned the hyperparameters of the Random Forest, as I thought it was the algorithm with the most room for improvement. To do that I use grid-search-based model selection with Spark’s ParamGridBuilder over two hyperparameters: the number of trees and the maximum depth.

numTrees is the number of decision tree classifiers that are trained and later combined to form the Random Forest. The set used is numTrees=[20,40].

maxDepth is the maximum depth of each tree in the forest. The deeper the tree, the more splits it has and the more information about the data it can potentially capture, at the risk of overfitting. The set used is maxDepth=[4,8].

I use 2-fold cross-validation to decide which combination of hyperparameters achieves the best results.
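
To see how much work this grid implies, the following plain-Python sketch expands the two value sets into candidate combinations and counts the model fits. The `expand_grid` helper is hypothetical and only mimics what a grid builder produces; it does not train anything.

```python
from itertools import product

# Values taken from the text above.
param_grid = {"numTrees": [20, 40], "maxDepth": [4, 8]}
num_folds = 2


def expand_grid(grid):
    """Return one dict per combination of hyperparameter values."""
    names = list(grid)
    value_lists = [grid[name] for name in names]
    return [dict(zip(names, combo)) for combo in product(*value_lists)]


candidates = expand_grid(param_grid)
# Each candidate is trained once per fold during cross-validation.
total_fits = len(candidates) * num_folds
```

With two values for each of two hyperparameters there are four candidates, so 2-fold cross-validation trains eight forests before the best combination is refit. The cost grows multiplicatively with every value added to the grid.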

Model evaluation and validation

The results are reasonably good: the models have notable predictive value as measured by the F1 scores and the AUC. The F1 score of the logistic regression is 0.80, reasonably close to the highest possible value of 1.0, which would indicate perfect precision and recall.

Additionally, I analysed the coefficients of the logistic regression and the feature importances produced by the Random Forest. The three features with the highest feature importance are the number of registered days, the number of unique artists that the user has listened to, and the number of sessions.

Justification of the chosen model

After trying to improve the existing algorithms with hyperparameter tuning, I decided that the best model is the logistic regression, for several reasons:

  • It is the simplest model and thus the easiest to understand. The coefficients can be interpreted.
  • The metrics are quite good: it has the highest F1 score (80.32%).
  • Next, I plotted the ROC curve to see the relation between the TPR (true positive rate) and the FPR (false positive rate).
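
For readers unfamiliar with the ROC curve, here is a minimal plain-Python sketch of how its points and the AUC can be computed from predicted scores. The labels and scores are invented, the helper names are hypothetical, and the sketch assumes both classes are present.

```python
def roc_points(y_true, scores):
    """Return (FPR, TPR) pairs, one per distinct score used as a threshold."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    points = [(0.0, 0.0)]  # threshold above every score: nothing is flagged
    for threshold in sorted(set(scores), reverse=True):
        flagged = [t for t, s in zip(y_true, scores) if s >= threshold]
        tp = sum(flagged)            # flagged users who churned
        fp = len(flagged) - tp       # flagged users who stayed
        points.append((fp / negatives, tp / positives))
    return points


def auc(points):
    """Area under the curve by the trapezoidal rule."""
    return sum(
        (x1 - x0) * (y0 + y1) / 2
        for (x0, y0), (x1, y1) in zip(points, points[1:])
    )


# Toy data, invented for illustration: 1 = churned, scores = predicted probability.
y_true = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.1]
points = roc_points(y_true, scores)
area = auc(points)
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect one; in this toy case three of the four churner/non-churner pairs are ranked correctly.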

Additionally, I use 4-fold cross-validation to assess in more detail how robust the predictions are to variations in the training data. I observe similar results in all the partitions: the F1 score ranges from 75% to 83%.
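
The partitioning behind k-fold cross-validation can be sketched in plain Python as follows. The `k_fold_indices` helper is hypothetical and uses contiguous folds for simplicity; in practice the rows would be shuffled first, and the project relies on Spark for this step.

```python
def k_fold_indices(n_rows, k):
    """Yield (train_idx, test_idx) pairs; each row is in exactly one test fold."""
    # Spread the remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        test_idx = list(range(start, stop))
        train_idx = list(range(0, start)) + list(range(stop, n_rows))
        yield train_idx, test_idx
        start = stop


# 215 users and 4 folds, as in the text above.
folds = list(k_fold_indices(215, 4))
test_sizes = [len(test_idx) for _train_idx, test_idx in folds]
```

Each model is trained on roughly three quarters of the users and scored on the remaining quarter, which is why the spread of the four scores says something about robustness.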

Reflection

This capstone project is a great exercise that puts into practice several data science skills (data analysis, cleansing, feature extraction, machine learning pipeline creation, model evaluation and fine-tuning) to solve a problem familiar to many businesses.

I found a couple of things interesting in this project. First, all the processes had to be coded with scalability in mind, since the full dataset is considerably large (12 GB). This was an opportunity to use Spark and some of its packages, such as spark.ml and spark.sql. Second, the project made me think about the kind of data and features that web platforms have and use.

Improvement

Several things remain to be done. Here I comment on just some of them:

  • First and most important, I still have to process the complete dataset using a Spark cluster. The results here are obtained from a subset of the dataset.
  • Improve the tuning of the random forest and gradient boosting by trying more hyperparameters.
  • Assess how unbalanced the data is. In practice, during a given period very few users churn relative to the total population.
  • Lastly, the feature engineering step deserves more attention, since many more features can be created.
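
As a first step towards assessing the imbalance, the class counts reported earlier (50 churned, 165 retained) already allow a quick back-of-the-envelope check. The inverse-frequency class weights at the end are one common option for compensating, shown here as an assumption and not something used in this project.

```python
# Counts reported in the data exploration section.
churned, stayed = 50, 165
total = churned + stayed

# Share of users in the positive (churn) class.
churn_rate = churned / total

# A trivial model that always predicts "stays" is right for every retained user,
# but it finds no churners at all, so its recall and F1 for churn are 0.
baseline_accuracy = stayed / total

# Inverse-frequency weights: each class contributes equally to a weighted loss.
class_weights = {
    "churned": total / (2 * churned),
    "stayed": total / (2 * stayed),
}
```

Roughly 23% of the users churn, so a model that never predicts churn would still reach about 77% accuracy. This is one more reason to judge the models by F1 and not by accuracy.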
