Deep Learning for Survival Analysis

By Seminar Information Systems (WS19/20) | February 6, 2020

Authors: Laura Löschmann, Daria Smorodina
Table of content

1. Motivation - Business case
2. Introduction to Survival Analysis
2.1 Common terms
2.2 Survival function
2.3 Hazard function
3. Dataset
4. Standard methods in Survival Analysis
4.1 Kaplan-Meier estimator
4.2 Cox proportional hazards model
4.3 Time-varying Cox regression
4.4 Random survival forests
5. Deep Learning for Survival Analysis
5.1 DeepSurv
5.2 DeepHit
6. Evaluation
6.1 Concordance index
6.2 DeepSurv - Survival curves
6.3 DeepHit - Hazard graphs
7. Conclusion
8. References


1. Motivation - Business case
With the financial crisis hitting the United States and Europe in 2008, the International Accounting Standards Board (IASB) decided to revise its accounting standards for financial instruments, e.g. loans or mortgages, to address perceived deficiencies which were believed to have contributed to the magnitude of the crisis. The result was the International Financial Reporting Standard 9 (IFRS 9) that became effective for all financial years beginning on or after 1 January 2018 [1].

Previously, impairment losses on financial assets were only recognised to the extent that there was objective evidence of impairment, meaning a loss event needed to occur before an impairment loss could be booked [2].

The new accounting rules for financial instruments require banks to build provisions for expected losses in their loan portfolio. The loss allowance has to be recognised before the actual credit loss is incurred. It is a more forward-looking approach than its predecessor, aiming at a more timely recognition of credit losses [3].

To implement the new accounting rules, banks need to build models that can evaluate a borrower's risk as accurately as possible. A key credit risk parameter is the probability of default. Classification techniques such as logistic regression and decision trees can be used to separate risky from non-risky loans. These classification techniques, however, do not take the timing of default into account. Survival analysis enables more accurate credit risk calculations, since it refers to a set of statistical techniques that can estimate the time it takes for a customer to default.

2. Introduction to Survival Analysis

Survival analysis, also called time-to-event analysis, refers to the set of statistical analyses that takes a series of observations and attempts to estimate the time it takes for an event of interest to occur.

The development of survival analysis dates back to the 17th century, with the first life table ever produced by the English statistician John Graunt in 1662. The name "survival analysis" comes from the longstanding application of these methods: for centuries they were linked solely to investigating mortality rates. During the last decades, however, the applications of survival analysis have been extended beyond medical research to other fields [4].

Survival analysis can be used in the field of health insurance to evaluate insurance premiums. It can be a useful tool in customer retention, e.g. to estimate the time at which a customer will probably discontinue their subscription. With this information the company can intervene early enough with incentives to retain the customer. The accurate prediction of upcoming churners results in highly targeted campaigns, limiting the resources spent on customers who likely would have stayed anyway. The methods of survival analysis can also be applied in the field of engineering, e.g. to estimate the remaining useful life of machines.
2.1 Common terms

Survival analysis is a collection of data analysis methods with the outcome variable of interest being the time to event. In general, event describes the event of interest, also called death event, time refers to the point of time of first observation, also called birth event, and time to event is the duration between the first observation and the time the event occurs [5]. The subjects whose data were collected for survival analysis usually do not have the same time of first observation; a subject can enter the study at any time. Using durations ensures the necessary relativeness [6]. Referring to the business case, the birth event is the initial recognition of a loan, the death event, consequently the event of interest, is the time a customer defaults, and the duration is the time between the initial recognition and the event of default.

During the observation time not every subject will experience the event of interest. Consequently it is unknown whether these subjects will experience the event of interest in the future, and the computation of the duration, the time from the first observation to the event of interest, is impossible. This special type of missing data can emerge due to two reasons:

1. The subject is still part of the study but has not experienced the event of interest yet.
2. The subject experienced a different event which also led to the end of study for this subject.

In survival analysis this missing data is called censorship, which refers to the inability to observe the variable of interest for the entire population. The censoring of data must be taken into account: dropping unobserved data would underestimate customer lifetimes and bias the results. Hence these subjects are labelled censored.

Since for censored subjects the death event could not be observed, this type of censorship is called right censoring, which is the most common one in survival analysis. As opposed to this, there is left censoring in case the birth event could not be observed.

The first reason for censored cases regarding the use case concerns loans that have not matured yet and did not experience default at the moment of data gathering.

The second reason for censorship refers to loans that did not experience the event of default but the event of early repayment. With this the loan is paid off, which results in the end of observation for this loan. This kind of censoring is used in models with one event of interest [7].

In terms of different application fields, an exact determination of the birth and death events, as well as of possible censoring cases besides the general case that the event of interest has not happened yet, is vital for the various use cases in the industry.

2.2 Survival function

The set of statistical methods related to survival analysis has the goal to estimate the survival function from survival data. The survival function S(t) defines the probability that a subject of interest will survive beyond time t, or equivalently, the probability that the duration will be at least t [8]. The survival function of a population is defined as follows:

S(t) = Pr(T > t)

T is the random lifetime taken from the population under study and cannot be negative. With regard to the business case it is the amount of time a customer is able to pay his loan rates, i.e. is not defaulting. The survival function S(t) outputs values between 0 and 1 and is a non-increasing function of t.

At the start of the study (t = 0), no subject has experienced the event yet. Therefore the probability S(0) of surviving beyond time zero is 1. S(∞) = 0, since if the study period were limitless, presumably everyone would eventually experience the event of interest and the probability of surviving would ultimately fall to 0. In theory the survival function is smooth; in practice the events are observed on a concrete time scale, e.g. days, weeks, months, etc., such that the graph of the survival function resembles a step function [9].
(Source: [9a])
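As a minimal illustration (not part of the original post), the step-function character can be seen by computing the empirical survival function of a few hypothetical, fully observed durations; censoring is deliberately ignored here and is handled properly by the Kaplan-Meier estimator in section 4.1:

import numpy as np
import matplotlib.pyplot as plt

# hypothetical, fully observed durations in months (no censoring)
durations = np.array([3, 7, 7, 12, 18, 24, 24, 36, 48, 60])

# empirical survival function: share of subjects with T > t
t_grid = np.arange(0, durations.max() + 1)
surv = [(durations > t).mean() for t in t_grid]

plt.step(t_grid, surv, where="post")
plt.xlabel("t (months)")
plt.ylabel("S(t)")
plt.show()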
2.3 Hazard function

Derived from the survival function, the hazard function h(t) gives the probability of the death event occurring at time t, given that the subject did not experience the death event until time t. It describes the instantaneous potential per unit time for the event to occur [10].

h(t) = lim_{δt → 0} Pr(t ≤ T ≤ t + δt | T > t) / δt

Therefore the hazard function models which periods have the highest or lowest chances of an event. In contrast to the survival function, the hazard function does not have to start at 1 and go down to 0. The hazard rate usually changes over time; it can start anywhere and go up and down. For instance, the probability of defaulting on a mortgage may be low in the beginning but can increase over the lifetime of the mortgage.
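For completeness, hazard and survival function are linked by the standard textbook relation (not stated explicitly in the original post), where f(t) denotes the density of the event times:

h(t) = \frac{f(t)}{S(t)} = -\frac{d}{dt}\,\ln S(t),
\qquad
S(t) = \exp\!\left(-\int_0^t h(u)\,du\right)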

(Source: [10a])

The graph shown above is a theoretical example of a hazard function [11]. This specific hazard function is also called a bathtub curve due to its form. The graph shows the probability of an event of interest occurring over time.

It could describe the probability of a customer unsubscribing from a magazine over time. Within the first 30 days the risk of unsubscribing is high, since the customer is testing the product. But if the customer likes the content, meaning he "survives" the first 30 days, the risk of unsubscribing decreases and stagnates at a lower level. After a while the risk increases again, since the customer may need different input or gets bored over time. Hence the graph gives the important information of when to initiate incentives for those customers whose risk of unsubscribing is about to increase, in order to retain them.

The main goal of survival analysis is to estimate and interpret survival and/or hazard functions from survival
data.

3. Dataset
We used a real-world dataset of 50,000 US mortgage borrowers which was provided by International Financial Research (www.internationalfinancialresearch.org). The data is given as a "snapshot" in a panel format and represents a collection of US residential mortgage portfolios over 60 periods. Loans can originate before the initial start of the study and can be paid off after it has finished.

When a person applies for a mortgage, lenders (banks) want to know the risk they would take by lending money. In the given dataset we are able to inspect this process using the key information from the following features:

Various timestamps for loan origination, future maturity and first appearance in the survival study.
Outside factors like gross domestic product (GDP) or unemployment rate at observation time.
Average price index at observation time.
FICO score for each individual: the higher the score, the lower the risk (a "good" credit score is considered to be in the 670-739 range).
Interest rate for every issued loan.
Since our object of analysis is mortgage data, we also have some insights into the inquired real estate types (single-family home or not, property in an area with urban development, etc.), which also play an important role for the prospective loan amount.

In order to use our data for survival analysis, we need to specify the characteristic terms. The birth event is the
time of the initial recognition of the mortgage, the death event is the default of the customer. The duration is the
time between the birth and death event. Some customers have not defaulted yet, so they will be labelled
“censored” in further analysis.

The graph below shows an example of the censorship concept at a specific point in time (13 months).
Some customers defaulted before this point in time (red lines) and some "survived" beyond it (blue lines); for the latter it is unknown at this point whether they will experience the event of interest.

Handling this kind of missing information is a main advantage of survival analysis. The distribution of the event of interest (in the graph below) shows that more than 2/3 of the customers are labelled as "censored". Dropping these observations would lead to a significant information loss and a biased outcome.

Survival analysis requires a specific dataset format:

E_i is the event indicator such that E_i = 1 if an event happens and E_i = 0 in case of censoring (column default_time),
T_i is the observed duration (column total_obs_time),
X_i is a p-dimensional feature vector (covariates, starting from the third column).
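A minimal sketch of this format in pandas, assuming the data has already been exported to a CSV file (the file name is hypothetical; the two special columns are the ones named above):

import pandas as pd

# data_cox: one row per loan, as described above
# default_time   -> event indicator E_i (1 = default, 0 = censored)
# total_obs_time -> observed duration T_i in months
# remaining columns -> covariate vector X_i (FICO, GDP, interest rate, ...)
data_cox = pd.read_csv("mortgage.csv")          # hypothetical file name

E = data_cox["default_time"]
T = data_cox["total_obs_time"]
X = data_cox.drop(columns=["default_time", "total_obs_time"])

print(E.value_counts(normalize=True))           # share of censored vs. defaulted loans
print(T.describe())                             # distribution of observed durations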

4. Standard methods in Survival Analysis


The standard estimation approaches can be classified into three main groups: non-parametric, semi-parametric, and parametric. The choice of method should be guided by the dataset design and the research question of interest. It is feasible to use more than one approach.

Parametric methods rely on the assumption that the distribution of the survival times corresponds to a specific probability distribution. This group consists of methods based on, for example, the exponential, Weibull and lognormal distributions. The parameters of these models are usually estimated via maximum likelihood.
Non-parametric methods make no assumptions about the form of the underlying distribution. Mostly, the non-parametric approach is used to describe survival probabilities as a function of time and to give an average view of the population. The most popular univariate method is the Kaplan-Meier estimator, which is used as a first step in descriptive survival analysis (section 4.1).
The semi-parametric group contains the Cox regression model, which combines parametric and non-parametric components (section 4.2).

Generally, the range of available statistical methods which can be implemented in survival analysis is very extensive, and a selection of them is introduced in the scope of our blog post. The diagram below gives a brief overview of them:
(Source: [18])

4.1 Kaplan-Meier estimator


The key idea of the Kaplan-Meier estimator is to break the estimation of the survival function S(t) into smaller
steps depending on the observed event times. For each interval the probability of surviving until the end of this
interval is calculated, given the following formula:

Ŝ(t) = ∏_{i: t_i ≤ t} (n_i − d_i) / n_i ,

where n_i is the number of individuals who are at risk at time point t_i and d_i is the number of subjects that experienced the event at time t_i.

When using the Kaplan-Meier estimator, some assumptions must be taken into account:

All observations - both censored and defaulted - are used in estimation.


There is no cohort effect on survival, so the subjects have the same survival probability regardless of their nature and time of appearance in the study.
Individuals who are censored have the same survival probabilities as those who continue to be examined.
The survival probability is equal for all subjects.

The main disadvantage of this method is that it cannot estimate the survival probability while taking all covariates in the data into account (it is a univariate approach); it yields no individual estimates but only the survival distribution of the overall population. In comparison, semi-parametric and parametric models allow all covariates to be analysed and S(t) to be estimated with respect to them.

from lifelines import KaplanMeierFitter

kmf = KaplanMeierFitter()

T = data_cox["total_obs_time"]
E = data_cox["default_time"]

kmf.fit(T, event_observed=E)

The estimated S(t) can be plotted as a stepwise function for the overall population. As an example, in the plot below it is clear that for time t = 10 months the probability that borrowers survive beyond this time is about 75%.
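The plot can be produced with lifelines' built-in plotting; a minimal continuation of the snippet above:

ax = kmf.plot_survival_function()   # stepwise estimate of S(t) with confidence interval
ax.set_xlabel("t (months)")
ax.set_ylabel("estimated S(t)")

# estimated probability of surviving beyond t = 10 months
print(kmf.predict(10))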
4.2 Cox proportional hazards model
The Cox proportional hazards model (CoxPH) involves not only the time and censorship features but also additional data as covariates (for our research all features of the dataset were used).

The Cox proportional hazards model (1972) is widely used in multivariate survival statistics due to its relatively easy implementation and informative interpretation. It describes the relationship between the survival distribution and the covariates. The dependent variable is expressed by the hazard function (or default intensity).
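In the original post the formula appears only as an image; it is the standard Cox proportional hazards form:

\lambda(t \mid x) \;=\; \lambda_0(t)\,\exp\big(\beta_1 x_1 + \dots + \beta_p x_p\big) \;=\; \lambda_0(t)\,\exp\big(\beta^\top x\big)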

This method is considered semi-parametric: it contains a parametric set of covariates and a non-parametric component λ_0(t), called the baseline hazard, which is the value of the hazard when all covariates are equal to 0.
The second component is the partial hazard, or hazard ratio, which defines the effect of the observed covariates on the baseline hazard λ_0(t).
These components are estimated by partial likelihood and are time-invariant.
In general, the Cox model estimates the hazard λ(t | x) from a linear combination of its static covariates (the log-risk function) and the baseline hazard.
Practical interpretation of Cox regression:
The sign of the partial hazard (coef column) for each covariate plays an important role. A positive sign increases the baseline hazard λ_0(t) and indicates that this covariate is associated with a higher risk of experiencing the event of interest. Conversely, a negative sign means that the risk of the event is lower.
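A minimal lifelines sketch of fitting the model and inspecting the coef column, assuming the same data_cox frame used for the Kaplan-Meier estimator:

from lifelines import CoxPHFitter

cph = CoxPHFitter()
cph.fit(data_cox, duration_col="total_obs_time", event_col="default_time")

# 'coef' column: sign and size of the partial (log-)hazards per covariate
cph.print_summary()

# exp(coef), i.e. the hazard ratios, for easier interpretation
print(cph.hazard_ratios_)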

The essential component of the CoxPH model is the proportionality assumption: the hazard functions of any two subjects stay proportional at any point in time, i.e. the hazard ratio does not vary with time. As an example, if a customer has a risk of loan default at some initial observation that is twice as low as that of another customer, then for all later observation times the risk of default remains twice as low.

Consequently, further important properties of the CoxPH model can be derived:

The times at which individuals may experience the event of interest are independent of each other.
Hazard curves of any individuals do not cross each other.
There is a multiplicative linear effect of the estimated covariates on the hazard function.

However, for the given dataset this proportionality property does not hold, as it is violated by some covariates. Several approaches can deal with this violation:

The first is binning these variables into smaller intervals and stratifying on them, keeping the covariates which do not obey the proportionality assumption in the model (see the sketch after this list). The problem that can arise in this case is information loss, since different values are now binned together.
We can expand the data to a time-varying format and apply a special type of Cox regression with continuous variables (section 4.3).
Random survival forests (section 4.4).
Extensions with neural networks (section 5).
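The violation can be detected, and the stratification workaround from the first bullet applied, with lifelines' built-in checks; a sketch under the same assumptions as above (the choice of stratification column is purely illustrative):

# statistical test plus advice per covariate; prints the variables violating proportionality
cph.check_assumptions(data_cox, p_value_threshold=0.05)

# example workaround: keep a violating binary covariate in the model via stratification
cph_strat = CoxPHFitter()
cph_strat.fit(data_cox, duration_col="total_obs_time",
              event_col="default_time", strata=["investor_orig_time"])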

4.3 Time-varying Cox regression


Earlier, we assumed that the predictors (covariates) are constant during the course of the follow-up. However, time-varying covariates can be included in survival models. The changes over time can be incorporated by using a special modification of the CoxPH model.
This extends the personal time of each individual into intervals of different length. The key assumption when including time-varying covariates is that their effect does not depend on time. Time-variant features should be used when it is hypothesized that the predicted hazard depends more significantly on later values of the covariate than on its value at baseline. A challenge with time-varying covariates is missing data in the covariate at different timesteps [15].

Before running the Cox regression model including the new covariates, it is necessary to pre-process the dataset into the so-called "long" format (where each duration is represented by start and stop columns) [8].
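lifelines ships a helper for this conversion; a minimal sketch, assuming the wide-format frame carries the panel id column that is used in the fit below:

from lifelines.utils import to_long_format

# turn each subject's single row into a (start, stop] interval row
data_cox_tv = to_long_format(data_cox, duration_col="total_obs_time")
print(data_cox_tv[["id", "start", "stop", "default_time"]].head())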

Fitting the Cox model on the modified time-varying data involves gradient descent (as for the standard proportional hazards model). Special built-in functions in the lifelines package help with convergence issues caused by high collinearity between some variables [8].

from lifelines import CoxTimeVaryingFitter

cox_tv = CoxTimeVaryingFitter()
cox_tv.fit(data_cox_tv, id_col="id", event_col="default_time",
           start_col="start", stop_col="stop")

# Iteration 30: norm_delta = 0.00000, step_size = 1.00000, ll = -179884.48600,
# newton_decrement = 0.00000, seconds_since_start = 207.0
# Convergence completed after 30 iterations.

4.4 Random survival forests


Another feasible machine learning approach which can be used to avoid the proportionality constraint of the Cox proportional hazards model is the random survival forest (RSF). The random survival forest is defined as a tree method that constructs an ensemble estimate for the cumulative hazard function. Constructing ensembles from base learners, such as trees, can substantially improve the prediction performance. [13]

Basically, RSF computes a random forest using the log-rank test as the splitting criterion. It calculates the cumulative hazards of the leaf nodes in each tree and averages them over the ensemble.
Each tree is grown to full size under the condition that each terminal node contains no fewer than a prespecified number of deaths [18].
The out-of-bag samples are then used to compute the prediction error of the ensemble cumulative hazard function.

The technical implementation is based on the scikit-survival package, which is built on top of scikit-learn and allows survival analysis to be implemented while utilizing the power of scikit-learn [14].

from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split

from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

rstate = 124

# Split the data into train/test subsets
# get_x_y_survival: helper (not shown here) converting the frame into covariates X and
# a structured array y of (event indicator, observed time) as expected by scikit-survival
X_rf, y_rf = get_x_y_survival(data_cox, 'default_time', 'total_obs_time', 1)
X_rf_train, X_rf_test, y_rf_train, y_rf_test = train_test_split(X_rf, y_rf, test_size=0.25,
                                                                random_state=rstate)

rsf = RandomSurvivalForest(n_estimators=50,
                           min_samples_split=7,
                           min_samples_leaf=10,
                           max_features="sqrt",
                           n_jobs=-1,
                           random_state=rstate,
                           verbose=1)

rsf.fit(X_rf_train, y_rf_train)

Above is a simple example of building an RSF to test this model on our survival data. Of course, hyperparameter tuning can be applied to the RSF in order to improve the accuracy metrics and the performance.
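As a quick check, scikit-survival scores the fitted forest with the concordance index introduced in section 6.1; a minimal continuation of the snippet above:

# RandomSurvivalForest.score returns the concordance index on the hold-out data
c_index = rsf.score(X_rf_test, y_rf_test)
print("Test concordance index: %.3f" % c_index)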

5. Deep Learning for Survival Analysis


Over the past years, a significant amount of machine learning research has been conducted on combining survival analysis with neural networks (the picture below gives an insight into the large scope of methods) [18]. With the development of deep learning technologies and computational capacities it is possible to achieve outstanding results and to implement a range of architectures on sizeable datasets with different underlying processes and more individualised learning.

We can define particular groups of methods regarding deep learning in survival analysis:
The first is based on further developments of the baseline Cox proportional hazards model: DeepSurv (section 5.1) and Cox-nnet (an extension of CoxPH for specific genetics datasets and regularizations) [16].
An alternative approach are fully parametric survival models which use RNNs to sequentially predict a distribution over the time to the next event: RNN-SURV, Weibull Time-To-Event RNN, etc. [17] [26]
In addition, there are new advanced deep learning networks, such as DeepHit, developed to also process survival data with competing risks (section 5.2).

(Source: [18])

5.1 DeepSurv
The initial adaptation of survival analysis to neural networks (Faraggi and Simon, 1995) was based on a generalization of the Cox proportional hazards model with only a single hidden layer. The main focus of this initial model was to learn the relationship between the covariates and the corresponding hazard risk function. Subsequent development of neural network architectures with Cox regression showed that in large real-world datasets with non-linear interactions between variables it is rather complicated to maintain the main proportionality assumption of the Cox regression model; Faraggi and Simon's network extended the model with exactly this non-linearity. [25]

A few years ago, a more sophisticated deep learning architecture, DeepSurv, was proposed by J.L. Katzman et al. as an extension of the Faraggi-Simon network. It showed improvements over the CoxPH model in the performance metrics when dealing with non-linear data [12]. This architecture is able to relax the main proportional hazards constraint. In addition, whereas the CoxPH model estimates the log-risk function h(x) as a linear combination of the static features of the given data x, with DeepSurv we can also drop this linearity assumption.

DeepSurv is a deep feed-forward neural network which estimates each individual's effect on their hazard rate with respect to the parametrized weights of the network θ. Generally, the structure of this neural network is quite straightforward. Compared to the Faraggi-Simon network, DeepSurv is configurable with multiple hidden layers.

The input data x is represented as a set of observed covariates.
The hidden layers in this model are fully-connected layers with non-linear activations, not necessarily with the same number of nodes in each of them, followed by dropout layers.
The output layer has only one node with a linear activation function which gives the output ĥ_θ (the estimated log-risk hazard).
Classically, the optimization of the Cox regression is done by maximizing the Cox partial likelihood with parametrized weights β.
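The likelihood itself is shown only as an image in the original post; written out, it takes the standard form:

L_c(\beta) \;=\; \prod_{i:\, e_i = 1} \frac{\exp\big(\hat{h}_\beta(x_i)\big)}{\sum_{j \in R(t_i)} \exp\big(\hat{h}_\beta(x_j)\big)}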

where t_i, e_i, x_i are the time, event and baseline covariate data of the i-th observation, respectively. More explicitly, this is a product of probabilities at the times t_i for the i-th observation, given the risk set R(t_i) of individuals that are not censored and have not experienced the event of interest before time t_i.

The loss function for this network is the negative log partial likelihood L_c(β) from the CoxPH (equation above) with an additional regularization term.
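In the original post this loss is also shown only as an image; following the DeepSurv paper, it reads:

l(\theta) \;=\; -\,\frac{1}{N_{e=1}} \sum_{i:\, e_i = 1}
\Big( \hat{h}_\theta(x_i) \;-\; \log \!\!\sum_{j \in R(t_i)} \exp\big(\hat{h}_\theta(x_j)\big) \Big)
\;+\; \lambda \,\lVert \theta \rVert_2^2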

where λ is the l2 regularization parameter and N_{e=1} is the number of individuals with an observable event.

In order to minimize the loss function with this regularization, it is necessary to maximize the part in the large parentheses: for every subject i experiencing the event the risk is increased, while censored subjects j that have not experienced the event before time t_i should have a minimized risk.

Practical implementation:
To build the DeepSurv model we found two implementation options:

1. https://github.com/jaredleekatzman/DeepSurv - the official repository of the discussed paper. However, the packages inside have not been updated recently and a range of useful functions is not available.
2. https://github.com/havakv/pycox - based on the PyTorch environment, a computationally fast approach to run survival analysis models. This package is used for DeepSurv here.

Firstly, we split the survival dataset into train, test and validation subsets, then standardize the given data (only the continuous variables), since our output layer is a linear Cox regression activation, and convert these subsets into arrays:

from sklearn_pandas import DataFrameMapper

data_ds = data_cox.copy()
df_train = data_ds.copy()

df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

cols_stand = ['balance_time', 'LTV_time', 'origination_time', 'maturity_time',
              'interest_rate_time', 'house_price_index_time', 'gdp_time', 'unemployment_rate_time',
              'balance_orig_time', 'FICO_orig_time',
              'LTV_orig_time', 'interest_rate_orig_time', 'house_price_index_orig_time']

cols_leave = ['investor_orig_time', 'real_estate_condominium',
              'real_estate_planned_urban_dev', 'real_estate_single_family_home', 'total_obs_time']

#standardize = [([col], StandardScaler) for col in cols_stand]

standardize = [([col], None) for col in cols_stand]
leave = [(col, None) for col in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)
x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

Some transformations of the target variable with event and duration information:

get_target = lambda df: (df['total_obs_time'].values, df['default_time'].values)
y_train = get_target(df_train)
y_val = get_target(df_val)
durations_test, events_test = get_target(df_test)
val = x_val, y_val

We build a vanilla MLP with four hidden layers, using:

Batch normalization (for stabilization and reducing data noise),
Dropout of 40% between the hidden layers,
ReLU as the activation function (alternatively, Scaled Exponential Linear Units (SELU) can be used),
The Adam optimizer for model training, without setting an initial learning rate value.

import torchtuples as tt
from pycox.models import CoxPH

# input settings for DeepSurv

n_nodes = 256
in_features = x_train.shape[1]
num_nodes = [n_nodes, n_nodes, n_nodes, n_nodes]
out_features = 1
batch_norm = True
dropout = 0.4
output_bias = False

net_ds = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout,
                                 output_bias=output_bias)

model_ = CoxPH(net_ds, tt.optim.Adam)

However, the resulting learning rate was too high; hence we set it to 0.001 in order to improve the performance:

batch_size = 128

# finding the best learning rate for this model
lrfinder = model_.lr_finder(x_train, y_train, batch_size, tolerance=10)
#_ = lrfinder.plot()
lrfinder.get_best_lr()

# setting the new value manually
model_.optimizer.set_lr(0.001)
model_.optimizer.param_groups[0]['lr']

epochs = 512

callbacks = [tt.callbacks.EarlyStopping()]
verbose = True

log = model_.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                 val_data=val, val_batch_size=batch_size)

The table below shows the set of hyperparameters used in training and optimization. Since there is no built-in hyperparameter search in the pycox package, these parameters were derived manually.

The final choice (lr = 0.001, batch_size = 128, number_nodes = 256) was based on the smallest loss value (it equals -7.2678223). Compared to the standard CoxPH (where the loss was ≈ -14.1) this is a significant improvement.

5.2 DeepHit
The model called "DeepHit" was introduced in a paper by Changhee Lee, William R. Zame, Jinsung Yoon and Mihaela van der Schaar in April 2018. It describes a deep learning approach to survival analysis implemented in a TensorFlow environment.

DeepHit is a deep neural network that learns the distribution of survival times directly. This means that the model makes no assumptions about an underlying stochastic process, so both the parameters of the model as well as the form of the stochastic process depend on the covariates of the specific dataset used for survival analysis. [18]

The model basically contains two parts, a shared sub-network and a family of cause-specific sub-networks. Due to this architecture, a great advantage of DeepHit is that it can easily be used for survival datasets with one single risk as well as with multiple competing risks. The dataset used so far describes one single risk, the risk of default. Customers that did not experience the event of interest are censored. The reason for censorship can either be that the event of interest was not experienced or that another event happened that also led to the end of observation but is not the event of interest for the survival analysis.
The original dataset has information about a second risk, the early repayment, also called payoff. For the previous analyses the dataset was preprocessed in a way that customers with an early repayment were also labelled censored, because the only event of interest was the event of default. If the second risk also becomes the focus of attention in terms of survival analysis, a second label for payoff (payoff = 2) can be introduced in the event column of the dataset. A competing risk is an event whose occurrence precludes the occurrence of the primary event of interest. [19]

The graph below shows the distribution of the target variable time to event within the dataset for competing risks.
In total more customers experience the event of payoff than face the event of default or become censored.
Throughout the observation time most of the customers who pay off early, repay their mortgage within the first
year. The proportion of the customers who default is also high within the first year. The amount of payoffs as well
as defaults per month decreases after the first year. Most of the censored customers are censored sometime
after 2.5 years besides a peak of censored customers at the ninth month.

To also handle competing risks DeepHit provides a flexible multi-task learning architecture. Multi-task learning
was originally inspired by human learning activities. People often apply the knowledge learned from previous
tasks to help learn a new task. For example, for a person who learns to ride the bicycle and unicycle together, the
experience in learning to ride a bicycle can be utilized in riding a unicycle and vice versa. Similar to human learning,
it is useful for multiple learning tasks to be learned jointly since the knowledge contained in a task can be
leveraged by other tasks. In the context of deep learning models, multiple models could be trained, each model
only learning one task (a). If this multiple tasks are related to each other, a multi-task learning model can be used
with the aim to improve the learning of a model by using the knowledge achieved throughout the learning of
related tasks in parallel (b). [20]

(Source: [20a])
Multi-task learning is similar to transfer learning but has some significant differences. Transfer learning models use several source tasks in order to improve the performance on the target task. Multi-task learning models treat all tasks equally; there is no task importance hierarchy and no attention focus on one specific task. The goal of multi-task learning models is to improve the performance of all tasks.

The most commonly used approach to multi-task learning in neural networks is called hard parameter sharing. The general architecture of such a multi-task learning model consists of two main parts. The first part is a shared sub-network, where the model learns the common representation of the related tasks. The model then splits into task-specific sub-networks in order to learn the non-common parts of the representation. The number of task-specific sub-networks is equal to the number of related tasks the model is trained on. For the sake of completeness, another approach to multi-task learning is soft parameter sharing, which describes an architecture where each task has its own model with its own parameters. To encourage the parameters to become similar, regularisation techniques are applied between the parameters of the task-specific models. Since DeepHit provides an architecture of hard parameter sharing, the approach of soft parameter sharing will be neglected in further explanations.

(Source: [20b])

To train a multi-task learning model, as many loss functions as tasks are required. The model is then trained by backpropagation. The fact that the task-specific sub-networks share common hidden layers allows comprehensive learning. Through the shared hidden layers, features that are developed in the hidden layers of one task can also be used by other tasks. Multi-task learning enables features to be developed that support several tasks, which would not be possible if multiple single-task learning models were trained on the related tasks in isolation. Also, some hidden units can specialise in one task, providing information that is not important for the other tasks. Keeping the weights to these hidden units small gives the other tasks the opportunity to ignore them. [21]

With multi-task learning a model can increase its performance for several reasons. By using the data of multiple related tasks, multi-task learning increases the sample size that is used to train the model, which is a kind of implicit data augmentation. The network sees more labels, even though these labels are not the labels of the same task but of highly related tasks. A model that learns different, similar tasks simultaneously is able to learn a more general representation that captures all of the tasks.
Moreover, by learning multiple tasks together the network has to focus on important information rather than task-specific noise. The other tasks provide additional evidence for the relevance or irrelevance of the features and help to draw the network's attention to the important features.
Some tasks are harder to learn, even by themselves. A model can benefit from learning a hard task combined with an easier related task. Multi-task learning allows the model to eavesdrop, i.e. learn the hard task through the simple related task, and therefore learn the hard task more easily and faster than in isolation.
In addition, different related tasks can act as a form of regularisation, since the model has to learn a general representation of all tasks. Learning the tasks in a single-task learning approach would bear the risk of overfitting on one task. [22]

Architecture of the DeepHit model for two competing risks


(Source: [22a])

The architecture of the DeepHit model is similar to the conventional multi-task learning architecture of hard parameter sharing, but has two main differences. DeepHit provides a residual connection between the original covariates and the input of the cause-specific sub-networks. This means that the input of a cause-specific sub-network is not only the output of the preceding shared sub-network but also the original covariates. This additional input allows the cause-specific sub-networks to better learn the non-common representation of the multiple causes.
The other difference refers to the final output of the model. DeepHit uses one single softmax output layer so that the model can learn the joint distribution of the competing events instead of their marginal distributions. Thus the output of the DeepHit model is a vector y for every subject in the dataset, giving the probabilities that the subject with covariates x will experience the event k at every timestamp t within the observation time. The probabilities of one subject sum up to 1.

The visualisation of the DeepHit model shows the architecture for a survival dataset with two competing risks. This architecture can easily be adjusted to more or fewer competing risks by adding or removing cause-specific sub-networks; the architecture of the DeepHit model depends on the number of risks.
To implement the model, the DeepHit repository has to be cloned to create a local copy on the computer.

DeepHit also needs the characteristic survival analysis input setting containing the event labels, the durations as well as the covariates. A function is provided that applies either standardisation or normalization to the data. For this analysis, standardisation was applied to the data.

The variable num_Category describes the dimension of the time horizon of interest and is needed in order to calculate the output dimension of the output layer of the model. num_Event gives the number of events excluding the case of censoring, since censoring is not an event of interest. This number defines the architecture of the model: it specifies the number of cause-specific sub-networks and is also needed to calculate the dimension of the output layer, which is the product of num_Category and num_Event. The input dimension is defined by the number of covariates used to feed the network.

import numpy as np

# Helper functions (f_get_Normalization, f_get_fc_mask2, f_get_fc_mask3) come from the
# cloned DeepHit repository

# DeepHit input settings

# Characteristic data format E, T, X
event = np.asarray(data[['event']])
time = np.asarray(data[['duration']])
dhdata = np.asarray(data.iloc[:, 1:18])

# Standardisation of the data
dhdata = f_get_Normalization(dhdata, 'standard')

# Dimension of time horizon of interest (equivalent to the output dimension per risk)
# Factor of 1.2 was used in the paper in order to have enough time-horizon
num_Category = int(np.max(time) * 1.2)

# Number of events (censoring is not included)
num_Event = int(len(np.unique(event)) - 1)

# Input dimension
x_dim = np.shape(dhdata)[1]

# Based on the data, mask1 and mask2 are needed to calculate the loss functions
# To calculate loss 1 - log-likelihood loss
mask1 = f_get_fc_mask2(time, event, num_Event, num_Category)
# To calculate loss 2 - cause-specific ranking loss
mask2 = f_get_fc_mask3(time, -1, num_Category)

DIM = (x_dim)
DATA = (dhdata, time, event)
MASK = (mask1, mask2)

The hyperparameters of DeepHit can be tuned by running random search using cross-validation. The function
get_random_hyperparameters randomly takes values for parameters out of a manually predefined range for
those parameters. Possible candidates for parameter tuning can be:

Batch size
Number of layers for the shared sub-network
Number of layers for the cause-specific sub-network
Number of nodes for the shared sub-network
Number of nodes for the cause-specific sub-network
Learning rate
Dropout
Activation function

The chosen parameters are forwarded to the function get_valid_performance along with the event labels, durations and covariates (summarized in DATA) as well as the masks for the loss calculations (summarized in MASK). This function takes the forwarded parameters to build a DeepHit model corresponding to the number of events of interest as well as the number of layers and nodes for the sub-networks. The dataset is then split into training, validation and test sets in order to start training the model on the training set using the chosen parameters. The training is done with mini batches of the training set over 50,000 iterations. Every 1,000 iterations a prediction is made on the validation set and the best model is saved to the specified file path. The evaluation of the models is based on the concordance index. The best result (= highest concordance index) is returned if there is no improvement over the next 6,000 iterations (early stopping). The concordance index is a measure for survival analysis models and is explained in detail in the evaluation part of this blog post.

import os

# Hyperparameter tuning

# Number of training/validation/test splits during tuning
OUT_ITERATION = 1

# Number of random search iterations
RS_ITERATION = 20

# For saving purposes of the best parameters
data_mode = 'mortgage'
out_path = data_mode + '/results'

# Times when the validation is performed
eval_times = [4, 8, 18]

for itr in range(OUT_ITERATION):

    if not os.path.exists(out_path + '/itr_' + str(itr) + '/'):
        os.makedirs(out_path + '/itr_' + str(itr) + '/')

    max_valid = 0
    log_name = out_path + '/itr_' + str(itr) + '/hyperparameters_log.txt'

    for r_itr in range(RS_ITERATION):
        print('OUTER_ITERATION: ' + str(itr))
        print('Random search... itr: ' + str(r_itr))
        new_parser = get_random_hyperparameters(out_path)
        print(new_parser)

        # get validation performance given the hyperparameters
        tmp_max = get_valid_performance(DATA, MASK, new_parser, itr, eval_times, MAX_VALUE=max_valid)

        if tmp_max > max_valid:
            max_valid = tmp_max
            max_parser = new_parser
            save_logging(max_parser, log_name)  # save the hyperparameters if they provide the best validation result so far

        print('Current best: ' + str(max_valid))

DeepHit is built with Xavier initialisation and dropout for all layers and is trained by backpropagation via the Adam optimizer. To train a survival analysis model like DeepHit, a loss function has to be minimised that is especially designed to handle censored data.

The loss function of the DeepHit model is the sum of two terms.

L_1 is the log-likelihood of the joint distribution of the first hitting time and event. This function is modified in a way that it captures censored data and considers competing risks if necessary. The log-likelihood function itself consists of two terms. The first term captures the event and the time at which the event occurred for the uncensored customers. The second term captures the time of censoring for the censored customers, giving the information that the customer did not default up to that time.

L_2 is a combination of cause-specific ranking loss functions, since DeepHit is a multi-task learning model and therefore needs cause-specific loss functions for training. The ranking loss function incorporates the estimated cumulative incidence function (CIF), calculated at the time the specific event occurred.
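The CIF formula is shown only as an image in the original post; following the DeepHit paper, it and its estimate from the softmax outputs y_{k,t} of the network read:

F_k(t^* \mid x) \;=\; P(T \le t^*,\; K = k \mid x),
\qquad
\hat{F}_k(t^* \mid x) \;=\; \sum_{t = 0}^{t^*} y_{k,t}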

This function expresses the probability that a particular event k occurs on or before time t, conditional on covariates x. To get the estimated CIF, the sum of the probabilities from the first observation time up to the time at which the event k occurred is computed.

The cause-specific ranking loss function adapts the idea of concordance: a customer that experienced the event k at a specific time t should have a higher probability than a customer that will experience the event sometime after this specific time t. The ranking loss function therefore compares pairs of customers that experienced the same event of interest and penalizes an incorrect ordering of the pairs.

After the training process the saved optimised hyperparameters as well as the corresponding trained model can
be used for the final prediction on the test dataset.

import tensorflow as tf
from sklearn.model_selection import train_test_split

# Model_DeepHit and load_logging come from the cloned DeepHit repository

for out_itr in range(OUT_ITERATION):
    # Load the saved optimised hyperparameters
    in_path = data_mode + '/results/'
    in_hypfile = in_path + '/itr_' + str(out_itr) + '/hyperparameters_log.txt'
    in_parser = load_logging(in_hypfile)

    # Forward the hyperparameters
    mb_size = in_parser['mb_size']
    iteration = in_parser['iteration']
    keep_prob = in_parser['keep_prob']
    lr_train = in_parser['lr_train']

    h_dim_shared = in_parser['h_dim_shared']
    h_dim_CS = in_parser['h_dim_CS']
    num_layers_shared = in_parser['num_layers_shared']
    num_layers_CS = in_parser['num_layers_CS']

    if in_parser['active_fn'] == 'relu':
        active_fn = tf.nn.relu
    elif in_parser['active_fn'] == 'elu':
        active_fn = tf.nn.elu
    elif in_parser['active_fn'] == 'tanh':
        active_fn = tf.nn.tanh
    else:
        print('Error!')

    initial_W = tf.contrib.layers.xavier_initializer()

    # Weights for the loss functions that can also be tuned
    # Default values: 1.0
    alpha = in_parser['alpha']  # for log-likelihood loss
    beta = in_parser['beta']    # for ranking loss

    # Create the dictionaries
    # For the input settings
    input_dims = {'x_dim': x_dim,
                  'num_Event': num_Event,
                  'num_Category': num_Category}

    # For the hyperparameters
    network_settings = {'h_dim_shared': h_dim_shared,
                        'h_dim_CS': h_dim_CS,
                        'num_layers_shared': num_layers_shared,
                        'num_layers_CS': num_layers_CS,
                        'active_fn': active_fn,
                        'initial_W': initial_W}

    # Create the DeepHit network architecture
    tf.reset_default_graph()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    model = Model_DeepHit(sess, "DeepHit", input_dims, network_settings)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # Training, test sets split
    (tr_data, te_data, tr_time, te_time, tr_label, te_label,
     tr_mask1, te_mask1, tr_mask2, te_mask2) = train_test_split(dhdata, time, event, mask1, mask2,
                                                                test_size=0.20, random_state=1234)

    # Restoring the trained model
    saver.restore(sess, in_path + '/itr_' + str(out_itr) + '/models/model_itr_' + str(out_itr))

    # Final prediction on the test set
    pred = model.predict(te_data)

6. Evaluation
6.1 Concordance index
For the evaluation of survival analysis models the performance measures need to take censored data into
account. The most common evaluation metric in survival analysis is the concordance index (c-index). It shows
the model’s ability to correctly provide a reliable ranking of the survival times based on the individual risk scores.
The idea behind concordance is that a subject that dies at time t should have a higher risk at time t than a subject
who survives beyond time t.

The concordance index expresses the proportion of concordant pairs in a dataset, thus estimates the
probability that, for a random pair of individuals, the predicted survival times of the two individuals have
the same ordering as their true survival times. A concordance index of 1 represents a model with perfect
prediction, an index of 0.5 is equal to random prediction. [23]

For a better understanding of this definition the concordance index is calculated on some simple example
predictions. The following table shows the true default times of four theoretical customers along with default
time predictions of three different models.

To calculate the concordance index, the number of concordant pairs has to be divided by the number of possible pairs. With four customers the following pairs are possible: (A,B), (A,C), (A,D), (B,C), (B,D), (C,D). The total number of possible pairs is 6.

Model 1 predicts that A defaults before B, and the true default times confirm that A defaults before B. The pair (A,B) is therefore a concordant pair. This comparison needs to be done for every possible pair. For the prediction of Model 1 all possible pairs are concordant, which results in a concordance index of 1 - perfect prediction.
For the prediction of Model 2 there are five concordant pairs, but for the pair (C,D) the model predicts that D defaults before C, whereas the true default times show that C defaults before D. With this the concordance index is 0.83 (5/6).
The concordance index of Model 3 is also equal to 1, since the model predicts the correct order of all possible pairs even though the individual predicted default times are not correct.

The next example shows the computation of the concordance index in case of right-censoring:

The first step is to determine the number of possible pairs. The default time of customer A can be compared to the default times of the other customers. Customer B is censored, which means that the only information given is the fact that customer B did not default up to time 2; there is no information whether customer B will default and, if so, when. Therefore a comparison between customers B and C as well as between customers B and D is impossible, because these customers defaulted after customer B was censored. The comparison between customers C and D is possible since both customers are not censored. In total there are four possible pairs: (A,B), (A,C), (A,D), (C,D). The second step is to check whether these possible pairs are concordant. The first three pairs are concordant, the pair (C,D) is discordant. The result is a concordance index of 0.75 (3/4). [24]
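lifelines exposes this computation directly; a minimal sketch with purely illustrative numbers (not the table values from the post):

from lifelines.utils import concordance_index

# true default times, predicted default times and event indicator (0 = censored)
true_times = [2, 4, 6, 8]
pred_times = [1, 3, 5, 7]     # same ordering as the true times -> fully concordant
events     = [1, 1, 1, 1]

print(concordance_index(true_times, pred_times, events))   # 1.0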

The dataset used for this blog post features the case of right-censoring, but the reason for censoring is that these customers are still in the phase of repaying and their loans have not matured yet. Therefore the time of censoring is equal to the last observation time. Due to this, the case that a customer defaults after another customer was censored cannot occur. The example of the concordance index in case of right-censoring is shown for the sake of completeness, since other survival datasets can have this case. A medical dataset, for example, can have data about patients with a heart disease. If a patient dies of a different cause than the heart disease, this patient would be censored. This can happen during the observation time, and other patients can die of the heart disease at a later time.

Model evaluation
The table shows the concordance indices of the models trained on the mortgage dataset. The benchmark models, CoxPH and random survival forest, start with a reasonable performance but are outperformed by the deep learning models, with the DeepHit model achieving the highest concordance index.

After evaluating the performance of the models we have a look into the output of the two best performing
models, DeepSurv and DeepHit.

6.2 DeepSurv - Survival curves


As we have already learned in section 4.1 about the Kaplan-Meier estimator, a survival curve is a graphical representation of the survival behaviour of subjects (i.e. mortgage borrowers) in the form of a graph showing the percentage surviving vs. time. This allows us to examine and compare estimated survival times for each individual (except for the Kaplan-Meier model) and to identify global patterns in the data (for example, sharp lines which drop close to 0% probability may have a specific explanation).

The graph below represents the estimated survival curves for 15 individual mortgage borrowers from the test dataset, using the output of the DeepSurv model. According to the graph, for a significant number of customers the predicted survival probabilities decrease within the first two years. For instance, the survival function of the customer with ID 5 shows a probability of roughly 50% to survive beyond 15 months, whereas the survival function of the customer with ID 9 at the same point in time shows only a 25% chance to survive beyond this time.
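A minimal sketch of how such curves can be obtained with pycox, continuing the DeepSurv snippets from section 5.1 (plotting 15 test customers is an illustrative choice):

# the semi-parametric Cox-type model needs baseline hazards before predicting survival
_ = model_.compute_baseline_hazards()

# survival function estimates for the test set; one column per customer
surv = model_.predict_surv_df(x_test)
ax = surv.iloc[:, :15].plot(drawstyle="steps-post", legend=False)
ax.set_xlabel("t (months)")
ax.set_ylabel("estimated S(t)")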

Towards the end of our study there is a certain flattening at t ≈ 42 months for a number of customers. A possible reason for this could be individual "treatments" provided by the bank, e.g. in order to reduce the maturity time.

6.3 DeepHit - Hazard graphs


The output of the DeepHit model is a vector for every customer giving the probabilities of the customer experiencing the event of interest at every point in time. The evaluation time is 72 months, therefore the output gives 72 probabilities per customer for experiencing the event of default (single risk). It is the joint distribution of the first hitting time and event, hence the probabilities of a customer sum to 1. The following graph visualises the output for every customer included in the test set (10,000 customers).
The graph shows that in the beginning there seems to be a higher risk of default, which decreases within the first two years; this also matches the predicted survival curves of the DeepSurv model. Throughout the evaluation time there are several probability increases for individual customers. There is a higher risk of default after the second and third year as well as within the period of the fifth and sixth year of credit time. Unfortunately it is not possible to compare these specific times to actual events in the past to derive any reasons for these peaks, since the periods of the mortgage dataset used for this analysis are deidentified. Thus it cannot be traced back when the data for this dataset was collected.

To get a closer look at the individual hazard graphs and to compare the predictions of the model to the true default times, the hazard graphs of a selection of six customers are plotted.
For the most part the hazard graphs of these customers show that within the first year the probability of default is higher and mostly decreases within the second year.

Hazard graph 1 also represents this trend. Throughout the rest of the evaluation time the probability values decrease and range between 0.5% and 2%. In the dataset the customer was censored after 26 months. With regard to the predicted hazard curve, if the customer "survives" beyond the first year he probably does not experience the event of default afterwards.
Hazard graph 2 starts with a high default probability after 3 months. With respect to the actual values - the customer defaulted after 3 months - the model made a precise prediction.
Hazard graph 3 shows the highest values between 10 and 13 months after the initial recognition of the mortgage, which matches the actual values of the customer defaulting after 13 months.
Hazard graph 4 differs from the other graphs since it starts with a low-risk period. The probability remains low until the start of the sixth year of credit time, except for a little increase at the end of the second year. The model predicts that if the customer experiences the event of default it will be sometime after the fifth year of credit time. The customer was censored after 39 months; he is still repaying his mortgage rates and has not experienced the event yet.
The customers of hazard graphs 5 and 6 were censored after a short time interval. They both have an increased risk of default within the first year. For customer 5 the second and third year is a low-risk period, followed by years of higher risk of default. Hazard graph 6 shows a decrease in hazard after the second year, but like hazard graphs 1 and 3 the probabilities vary at low values until the end of the evaluation time.

In case of two competing risks the output of DeepHit is a vector of length 144 for every customer. This length
comes from 72 probabilities of experiencing event 1 (default) and 72 probabilities of experiencing event 2
(payoff). The vector gives the joint probability distribution of both events, so the sum of a vector of one customer
is equal to 1.

To get an overview of the predictions, the output of every customer per event is visualised. When comparing the graphs, the different ranges of the risk probabilities have to be noted. The first graph shows the hazard probabilities of the customers experiencing the event of default. In the beginning the risk of default is higher but decreases within the first two years, reaching a low-risk period within years three and four. After that period the probability of default increases locally, and for some individual customers the model predicts the highest risk of default after 5 years, which is probably a result of the censored data. The risk of payoff is higher in the beginning compared to the default risk, since in total more customers experience the event of payoff than the event of default. It decreases slightly throughout the first 4.5 years; for some customers the event of payoff is quite likely within the fifth year.

We now look at selected individual hazard graphs plotting the joint distribution of both events per customer, in order to compare the predictions with the true event times of the selected customers.
Hazard graph 1 gives a higher probability of experiencing default than payoff. Moreover, the model predicts the default to be at the end of the first year, which matches the true default time of the customer experiencing the event of default after 13 months.
Hazard graph 2 starts with a low-risk period of more than two years regarding both events. After 2.5 years the risk of early repayment increases, but after four years of credit time the model also predicts a strongly increased hazard of default. In total the model predicts a slightly higher risk of payoff. The customer was censored after 39 months, which corresponds to the long period of low risk of experiencing one of these events, but with regard to this customer the model is not able to make a strong prediction for either default or payoff.
Hazard graph 3 shows a high risk of payoff right in the beginning. The prediction matches the customer's true event time of experiencing the event of payoff after 1 month.
Hazard graph 4 is similar to the third graph and also leads to a good prediction of payoff after 4 months, which matches the actual values of the customer. The graph shows a sudden increase in payoff risk around 4.5 years that again decreases to zero risk afterwards, which is probably a result of a pattern the model learned, but looks more like an unrealistic outlier.

Overall, the DeepHit models for single as well as for competing risks can already make good predictions on the test dataset. With regard to the initially introduced business case, the predicted probability values of each customer can be used to calculate the expected credit loss in order to set up the provisions as a counterbalance to the recognised values of the loans. The formula of the expected credit loss is:

(Source: [27])
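The formula itself appears only as an image in the original post; it is presumably the standard expected-credit-loss decomposition, where LGD is the loss given default and EAD the exposure at default:

\text{ECL} \;=\; \text{PD} \times \text{LGD} \times \text{EAD}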

The output of the survival analysis provides the probability of default, i.e. the part of the formula shown in the red box above. The more precise the predictions of the survival analysis models, the more exact the calculation of the expected credit losses, which has an impact on the bank's income statement.

7. Conclusion
We hope that our blog post gives a clear overview of survival analysis and perhaps inspires its use in further academic or professional work. Standard survival statistics, such as the Cox proportional hazards model, already allow meaningful insights to be gained from data without any sophisticated model implementation.

The advanced extension of survival analysis models with machine learning practices gives more methodological freedom. With a proper hyperparameter tuning process it is possible to achieve more precise predictions of the time-to-event target variable.

The format of the dataset is exceptionally important. In order to apply survival analysis techniques, the data has to meet the requirements of the characteristic survival analysis data points: event, duration and valuable features.

The implementation of more complex survival analysis models in Python is still in development. With the increasing popularity of these methods in different industries, we hope that it is just a question of time until the variety of functions within the survival analysis packages grows.

Thanks for reading our blogpost and surviving it :)
