Survival Analysis: Leveraging Deep Learning for Time-to-Event Forecasting


Illustration by the author

Practical Application to Rehospitalization

Survival models are great for predicting the time until an event occurs. These models can be used in a wide range of use cases, including predictive maintenance (forecasting when a machine is likely to break down), marketing analytics (anticipating customer churn), patient monitoring (predicting when a patient is likely to be re-hospitalized), and much more.

By combining machine learning with survival models, the resulting models benefit from the high predictive power of the former while retaining the framework and typical outputs of the latter (such as the survival probability or the hazard curve over time). For more information, take a look at the first article of this series here.

However, in practice, ML-based survival models still require extensive feature engineering, and thus prior business knowledge and intuition, to lead to satisfying results. So, why not use deep learning models instead to bridge the gap?

Objective

This article focuses on how deep learning can be combined with the survival analysis framework to solve use cases such as predicting the likelihood of a patient being (re)hospitalized.

After reading this article, you will understand:

  1. How can deep learning be leveraged for survival analysis?
  2. What are the common deep learning models in survival analysis and how do they work?
  3. How can these models be applied concretely to hospitalization forecasting?

This article is the second part of the series on survival analysis. If you are not familiar with survival analysis, it is best to start by reading the first one here. The experiments described in this article were carried out using the libraries scikit-survival, pycox, and plotly. You can find the code here on GitHub.

1.1. Problem statement

Let's start by describing the problem at hand.

We are interested in predicting the likelihood that a given patient will be rehospitalized, given the available information about their health status. More specifically, we would like to estimate this probability at different time points after the last visit. Such an estimate is essential to monitor patient health and mitigate the risk of relapse.

This is a typical survival analysis problem. The data consists of three elements:

The patient's baseline data, including:

  • Demographics: age, gender, locality (rural or urban)
  • Patient history: smoking, alcohol, diabetes mellitus, hypertension, etc.
  • Laboratory results: hemoglobin, total lymphocyte count, platelets, glucose, urea, creatinine, etc.
  • More information about the source dataset can be found here [1].

A time t and an event indicator δ ∈ {0, 1}:

  • If the event occurs within the observation window, t is equal to the time between the moment the data were collected and the moment the event (i.e., rehospitalization) is observed. In that case, δ = 1.
  • If not, t is equal to the time between the moment the data were collected and the last contact with the patient (e.g., end of the study). In that case, δ = 0.
Figure 1 — Survival analysis data, illustration by the author. Note: patients A and C are censored.
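
To make this structure concrete, here is a minimal sketch of how such a dataset could be laid out in code. The patients and values are purely illustrative (they are not taken from the real dataset), and the column names are arbitrary choices:

# Illustrative layout of survival data (hypothetical values)
import pandas as pd

data = pd.DataFrame({
    "patient":  ["A", "B", "C"],
    "age":      [63, 48, 71],     # one example baseline covariate among many
    "duration": [180, 95, 365],   # time t (e.g., in days) since the data were collected
    "event":    [0, 1, 0],        # δ: 1 if rehospitalization was observed, 0 if censored
})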

⚠️ With this description, why use survival analysis methods when the problem is so similar to a regression task? The original paper gives a fairly good explanation of the main reason:

“If one chooses to use standard regression methods, the right-censored data becomes a type of missing data. It is usually removed or imputed, which may introduce bias into the model. Therefore, modeling right-censored data requires special attention, hence the use of a survival model.” Source [2]

1.2. DeepSurv

Approach

Let's move on to the theoretical part with a little refresher on the hazard function.

“The hazard function is the probability an individual will not survive an extra infinitesimal amount of time δ, given they have already survived up to time t. Thus, a greater hazard signifies a greater risk of death.”

Source [2]

Like the Cox proportional hazards (CPH) model, DeepSurv relies on the assumption that the hazard function is the product of two functions:

  • the baseline hazard function: λ_0(t)
  • the risk score, r(x) = exp(h(x)), which models how the hazard function deviates from the baseline for a given individual, given the observed covariates.

More on CPH models in the first article of this series.

The function h(x) is commonly called the log-risk function, and this is precisely the function that the DeepSurv model aims to model.

In fact, CPH models assume that h(x) is a linear function: h(x) = β · x. Fitting the model thus consists of computing the weights β that optimize the objective function. However, the linear proportional hazards assumption does not hold in many applications. This justifies the need for a more complex non-linear model that is ideally capable of handling large volumes of data.
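
In equation form (standard CPH notation, consistent with the quantities defined above):

\lambda(t \mid x) = \lambda_0(t)\,\exp\big(h(x)\big), \qquad h_{\text{CPH}}(x) = \beta \cdot x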

Architecture

In this context, how can the DeepSurv model provide a better alternative? Let's start by describing it. According to the original paper, it is a “deep feed-forward neural network which predicts the effects of a patient's covariates on their hazard rate parameterized by the weights of the network θ.” [2]

How does it work?

‣ The input to the network is the baseline data x.

‣ The network propagates the inputs through a number of hidden layers with weights θ. The hidden layers consist of fully connected layers with nonlinear activation functions, followed by dropout.

‣ The final layer is a single node that performs a linear combination of the hidden features. The output of the network is taken as the predicted log-risk function.

Source [2]

Figure 2 — DeepSurv architecture, illustration by the author, inspired by source [2]
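
As a rough illustration of this description, here is what such a network could look like in plain PyTorch. This is only a sketch: the layer sizes and dropout rate are arbitrary, and the pycox helper used in the code sample below builds a comparable architecture for you.

# DeepSurv-style network: hidden layers plus a single linear output node for the log-risk
import torch.nn as nn

in_features = 12  # number of baseline covariates (example value)

log_risk_net = nn.Sequential(
    nn.Linear(in_features, 32), nn.ReLU(), nn.BatchNorm1d(32), nn.Dropout(0.1),
    nn.Linear(32, 32), nn.ReLU(), nn.BatchNorm1d(32), nn.Dropout(0.1),
    nn.Linear(32, 1, bias=False),  # single node: the estimated log-risk h(x)
)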

As a result of this architecture, the model is very flexible. Hyperparameter search techniques are typically used to determine the number of hidden layers, the number of nodes in each layer, the dropout probability, and other settings.

What about the objective function to optimize?

  • CPH models are trained to optimize the Cox partial likelihood. It consists of calculating, for each patient i who experiences the event at time Ti, the probability that the event happened to that patient rather than to any of the individuals still at risk at time Ti, and then multiplying all these probabilities together. You can find the exact mathematical formula here [2].
  • Similarly, the objective function of DeepSurv is the average negative log of the same partial likelihood, with an additional term that regularizes the network weights [2], as written out below.
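
For reference, the loss minimized by DeepSurv in the original paper [2] can be written as follows, where N_{δ=1} is the number of patients with an observed event, ℛ(T_i) is the set of patients still at risk at time T_i, and λ controls the strength of the regularization:

l(\theta) = -\frac{1}{N_{\delta=1}} \sum_{i:\, \delta_i = 1} \left( \hat{h}_\theta(x_i) - \log \sum_{j \in \mathcal{R}(T_i)} e^{\hat{h}_\theta(x_j)} \right) + \lambda \, \lVert \theta \rVert_2^2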

Code sample

Here is a small code snippet to give an idea of how this type of model is implemented using the pycox library. The complete code can be found in the notebook examples of the library here [6].

# Imports (the snippet assumes pycox and torchtuples are installed)
import torchtuples as tt
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv

# Step 1: Neural net
# simple MLP with two hidden layers, ReLU activations, batch norm and dropout

in_features = x_train.shape[1]
num_nodes = [32, 32]
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

model = CoxPH(net, tt.optim.Adam)

# Step 2: Model training

batch_size = 256
epochs = 512
callbacks = [tt.callbacks.EarlyStopping()]
verbose = True

model.optimizer.set_lr(0.01)

log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                val_data=val, val_batch_size=batch_size)

# Step 3: Prediction

_ = model.compute_baseline_hazards()
surv = model.predict_surv_df(x_test)

# Step 4: Evaluation

ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
ev.concordance_td()
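
The same EvalSurv object can also compute the time-dependent Brier score and its integrated version, which are used in the benchmark later on. A minimal sketch, assuming durations_test is a NumPy array (the 100-point time grid is an arbitrary choice):

# Step 5: Brier score over a time grid
import numpy as np

time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev.brier_score(time_grid)             # time-dependent Brier score (pandas Series)
ev.integrated_brier_score(time_grid)  # integrated Brier score (single float)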

1.3. DeepHit

Approach

Instead of making strong assumptions about the distribution of survival times, what if we could train a deep neural network that would learn it directly?

This is the case with the DeepHit model. Specifically, it brings two significant improvements over previous approaches:

  • It does not rely on any assumptions about the underlying stochastic process. Thus, the network learns to model the evolution over time of the relationship between the covariates and the risk.
  • It can handle competing risks (e.g., simultaneously modeling the risks of being rehospitalized and of dying) through a multi-task learning architecture.

Architecture

As described here [3], DeepHit follows the common architecture of multi-task learning models. It consists of two main parts:

  1. A shared subnetwork, where the model learns from the data a general representation useful for all the tasks.
  2. Task-specific subnetworks, where the model learns more task-specific representations.

However, the architecture of the DeepHit model differs from typical multi-task learning models in two respects:

  • It includes a residual connection between the initial covariates and the input of the task-specific subnetworks.
  • It uses a single softmax output layer. As a result, the model learns the joint distribution of the competing events rather than their marginal distributions.

The figure below shows the case where the model is trained simultaneously on two tasks.

The output of the DeepHit model is a vector y for every subject. It gives the probability that the subject will experience the event k ∈ {1, 2} at every timestamp t within the observation time.

Figure 3 — DeepHit architecture, illustration by the author, inspired by source [4]
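
To give an idea of what this looks like in practice, below is a minimal sketch of the single-risk variant with pycox, following the library's DeepHit examples. The discretization into 10 time intervals and the alpha/sigma loss weights are arbitrary choices, and durations_train/events_train (and their validation counterparts) are assumed to be NumPy arrays:

# DeepHit (single risk) with pycox: illustrative sketch
import torchtuples as tt
from pycox.models import DeepHitSingle

num_durations = 10                                       # discretize time into 10 intervals
labtrans = DeepHitSingle.label_transform(num_durations)
y_train = labtrans.fit_transform(durations_train, events_train)
y_val = labtrans.transform(durations_val, events_val)

net = tt.practical.MLPVanilla(in_features, [32, 32], labtrans.out_features,
                              batch_norm=True, dropout=0.1)
model = DeepHitSingle(net, tt.optim.Adam, alpha=0.2, sigma=0.1,
                      duration_index=labtrans.cuts)

model.optimizer.set_lr(0.01)
model.fit(x_train, y_train, batch_size=256, epochs=100,
          callbacks=[tt.callbacks.EarlyStopping()],
          val_data=(x_val, y_val))

surv = model.predict_surv_df(x_test)  # survival probabilities over the discrete time grid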

2.1. Methodology

Data

The data set was divided into three parts: a training set (60% of the data), a validation set (20%), and a test set (20%). The training and validation sets were used to optimize the neural networks during training, and the test set was used for the final evaluation.
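
A minimal sketch of such a split, assuming the baseline covariates live in a DataFrame X (the use of scikit-learn's train_test_split and the fixed random seed are assumptions, not necessarily what the original code does):

# 60% train / 20% validation / 20% test split
from sklearn.model_selection import train_test_split

idx_train, idx_tmp = train_test_split(X.index, test_size=0.4, random_state=42)
idx_val, idx_test = train_test_split(idx_tmp, test_size=0.5, random_state=42)
x_train, x_val, x_test = X.loc[idx_train], X.loc[idx_val], X.loc[idx_test]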

Benchmark

The performance of the deep learning models was compared to a benchmark of models including CoxPH and ML-based survival models (Gradient Boosting and SVM). More information on these models is available in the first article of the series.

Metrics

Two metrics were used to evaluate the models:

  • Concordance index (C-index): it measures the ability of the model to provide a reliable ranking of survival times based on individual risk scores. It is computed as the proportion of concordant pairs in a dataset.
  • Brier score: it is a time-dependent extension of the mean squared error to right-censored data. In other words, it represents the average squared distance between the observed survival status and the predicted survival probability. A sketch of how both metrics can be computed is shown after this list.
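
For the non-deep-learning benchmark models, both metrics are available in scikit-survival. A minimal sketch, where the variable and field names are assumptions (they depend on how the structured survival arrays and the predictions were built):

# C-index and integrated Brier score with scikit-survival
from sksurv.metrics import concordance_index_censored, integrated_brier_score

# y_train / y_test: structured arrays with a boolean event field and a float time field
c_index = concordance_index_censored(y_test["event"], y_test["time"], risk_scores)[0]

# surv_probs: shape (n_test_samples, len(times)), predicted survival probabilities
ibs = integrated_brier_score(y_train, y_test, surv_probs, times)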

2.2. Results

In terms of the C-index, the performance of the deep learning models is considerably better than that of the ML-based survival analysis models. Furthermore, there is almost no difference between the performance of the DeepSurv and DeepHit models.

Figure 4 — C-Index of models on the train and test sets

In terms of the Brier score, the DeepSurv model stands out from the others.

  • When examining the Brier score as a function of time, the curve of the DeepSurv model is lower than the others, which reflects better accuracy.
Figure 5 — Brier score on the test set
  • This observation is confirmed when considering the integral of the score over the same time interval.
Figure 6 — Integrated Brier score on the test set

Note that the Brier score was not computed for the SVM, as this score is only applicable to models that are able to estimate a survival function.

Figure 7 — Survival curves of randomly chosen patients using the DeepSurv model

Finally, deep learning models can be used for survival analysis just as well as statistical models. Here, for instance, we can see the survival curves of randomly chosen patients. Such outputs bring many benefits, in particular allowing an easier follow-up of the patients who are most at risk.

✔️ Survival models are very useful for predicting the time it takes for an event to occur.

✔️ They can help address many use cases by providing a learning framework and techniques, as well as useful outputs such as the survival probability or the hazard curve over time.

✔️ They are even indispensable in this kind of use case to take advantage of all the data, including the censored observations (when the event does not occur within the observation period, for instance).

✔️ ML-based survival models tend to perform better than statistical models (more information here). However, they require high-quality feature engineering based on solid business intuition to achieve satisfactory results.

✔️ This is where deep learning can bridge the gap. Deep learning-based survival models like DeepSurv or DeepHit have the potential to perform better with less effort!

✔️ Nevertheless, these models are not without drawbacks. They require a large dataset for training and the fine-tuning of multiple hyperparameters.

[1] Bollepalli, S.C.; Sahani, A.K.; Aslam, N.; Mohan, B.; Kulkarni, K.; Goyal, A.; Singh, B.; Singh, G.; Mittal, A.; Tandon, R.; Chhabra, S.T.; Wander, G.S.; Armoundas, A.A. An Optimized Machine Learning Model Accurately Predicts In-Hospital Outcomes at Admission to a Cardiac Unit. Diagnostics 2022, 12, 241.

[2] Katzman, J., Shaham, U., Bates, J., Cloninger, A., Jiang, T., & Kluger, Y. (2016). DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network, ArXiv

[3] Laura Löschmann, Daria Smorodina, Deep Learning for Survival Analysis, Seminar Information Systems (WS19/20), February 6, 2020

[4] Lee, Changhee et al. DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks. AAAI Conference on Artificial Intelligence (2018).

[5] Wikipedia, Proportional hazards model

[6] Pycox library
