
Pattern Recognition 131 (2022) 108876


A novel explainable neural network for Alzheimer's disease diagnosis ☆


Lu Yu a, Wei Xiang b,∗, Juan Fang c, Yi-Ping Phoebe Chen b, Ruifeng Zhu d

a College of Science and Engineering, James Cook University, Cairns, QLD 4878, Australia
b School of Computing, Engineering and Mathematical Sciences, La Trobe University, Melbourne, VIC 3086, Australia
c Faculty of Information Technology, Beijing University of Technology, Beijing 100124, China
d Department of Engineering "Enzo Ferrari", University of Modena and Reggio Emilia, Modena, Italy

Article history: Received 24 August 2021; Revised 19 May 2022; Accepted 26 June 2022; Available online 28 June 2022

Keywords: Explainable neural networks; XAI; High-resolution heatmap; MRI

Abstract: Visual classification for medical images has been dominated by convolutional neural networks (CNNs) for years. Though they have shown great performance on accuracy, some of them provide decisions that are hard to explain, while others encode information from irrelevant or noisy regions. In this work, we try to close this gap by proposing an explainable framework consisting of a predictor and an explainable tool, so as to provide accurate diagnoses together with intuitive visualization maps and a prediction basis. Specifically, the predictor is designed by applying attention mechanisms to multi-scale features so as to learn and discover class-discriminative latent representations that are close to each brain volume's label. Meanwhile, to explain our predictor, we propose a novel explainable tool which includes a high-resolution visualization method and a prediction-basis creation and retrieval module. The former effectively integrates the feature maps of intermediate layers as well as the last convolutional layer, surpassing state-of-the-art visualization approaches in producing high-resolution representations with more accurate localization of discriminative areas, while the latter provides prediction-basis evidence via retrieved volumes with similar latent representations that are accessible to neurologists. Extensive experiments show that the proposed framework achieves a higher level of accuracy and explainability than other state-of-the-art solutions. More importantly, it localizes crucial brain areas with clearer boundaries and less noise, which matches background knowledge in the neuroscience literature.

© 2022 Elsevier Ltd. All rights reserved.

1. Introduction

Alzheimer's disease (AD), the most common form of dementia, which can induce movement disorders and a series of subsequent syndromes, has affected over 50 million people worldwide and is growing rapidly [1]. Traditionally, computer-aided detection of AD using machine learning methods develops feature descriptors and classification systems. However, hand-crafted features suffer from subjectivity and cannot generalize well across instances. Thanks to extensive research on applications of deep learning (DL) such as CNNs, medical scientists have sought a new era of engagement with AI-based diagnosis for detecting AD at an early stage automatically [2]. With the advance of magnetic resonance technology, magnetic resonance imaging (MRI) data are often provided to observe the development of brain tissue morphology related to AD [3]. Plenty of DL architectures have been proposed to classify AD using brain MRIs and have gained satisfactory performance [4]. However, despite their significant achievements, the predictions of existing models are not faithful to the expected reasoning. That is, they do not provide any explicit visual or other forms of explainable information associated with the final output. This has become a major hurdle to applying these techniques on a mass scale due to the lack of human trust.
Explainable Artificial Intelligence (XAI) is an emerging sub-field of AI that pursues capturing the properties that influence the decision of a model [5]. To fully uncover CNNs, several works have proposed to build interpretable CNN models. Wang et al. [6] proposed a general approach to train interpretable convolutional filters in CNN models, wherein each filter represents a certain part of the object. Lee et al. [7] designed a model that makes final decisions based on the regional abnormality representation by use of complex nonlinear relationships among voxels.

☆ Data used in preparation of this article were obtained from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu). As such, the investigators within the ADNI contributed to the design and implementation of ADNI and/or provided data but did not participate in analysis or writing of this report. A complete listing of ADNI investigators can be found at: http://adni.loni.usc.edu/wp-content/uploads/how_to_apply/ADNI_Acknowledgement_List.pdf.
∗ Corresponding author.
E-mail addresses: lu.yu@my.jcu.edu.au (L. Yu), w.xiang@latrobe.edu.au (W. Xiang), fangjuan@bjut.edu.cn (J. Fang), phoebe.chen@latrobe.edu.au (Y.-P. Phoebe Chen), reefing.z@gmail.com (R. Zhu).
https://doi.org/10.1016/j.patcog.2022.108876

Fig. 1. Visualization results of state-of-the-art methods for an AD patient: (a) CAMERAS [8]; (b) Grad-CAM [9]; (c) Grad-CAM++ [10]; (d) Score-CAM [6]. All of them provide blurry heatmaps or recognize irrelevant noise.

However, most existing methods only provide blurry heatmaps or recognize atrophy with irrelevant noise (Fig. 1). This can be attributed to the fact that the leveraged last convolutional layer only extracts global features and misses the small attributes and discrepancies. Therefore, such methods are not able to provide enough detail to precisely recognize crucial areas, and fail to localize small differences in medical imaging diagnosis.

Different from existing XAI works in the literature, we aim to develop an explainable framework for automated diagnosis of AD capable of providing accurate predictions with fine-grained heatmaps and prediction reasoning. We first build an explainable network dubbed MAXNet with two novel modules, the Dual Attention Module (DAM) and the Multi-resolution Fusion Module (MFM), to capture and fuse multi-resolution features. Intuitively, we want MAXNet to learn representations containing all of the voxel information necessary for correct predictions. Therefore, we design cluster and contrastive loss functions to make the model learn and extract semantically informative latent features of the target label. Second, to provide high-resolution heatmaps and prediction reasoning, we propose an explainable tool which consists of a novel visualization method termed High-resolution Activation Mapping (HAM), and a Prediction-basis Creation and Retrieval (PCR) module. The former yields fine-grained heatmaps for disease areas, while the latter creates a prediction reference set during training, from which subjects similar to a query volume are retrieved during testing, to enhance the explainability of predictions. In short, our main contributions are three-fold:

• We propose a novel 3D interpretable model, dubbed MAXNet, which can effectively aggregate multi-scale features for AD detection and learn latent features that are representative of each volume's label;
• We present a novel high-resolution visualization method, termed High-resolution Activation Mapping (HAM), that produces high-resolution visual explanations for the precise localization of disease areas by aggregating the learned discriminative representations of intermediate layers as well as the last convolutional layer;
• We propose a Prediction-basis Creation and Retrieval (PCR) module, which leverages latent representations to collect similar reference samples as visual evidence for the case analysis of AD.

2. Related work

2.1. Traditional methods for AD diagnosis

There has been much interest in feature selection techniques to assist the diagnosis of AD using individual brain MRIs. A group of studies have proposed to utilize hand-crafted features extracted from MRI data in combination with different models. Zhang et al. [11] proposed a multi-task feature selection (MTFS) method that selects subsets of features from each modality. Based on this, Liu et al. [12] developed an inter-modality feature selection method (IMTFS) to process the complementary inter-modality features. Zhu et al. [13] adopted manifold regularized multi-task learning for AD diagnosis. Moreover, Shi et al. [1] first developed a nonlinear feature engineering module, then used a support vector machine (SVM) to identify AD patients. Cao et al. [14] explicitly extracted subset features and regions of interest (ROIs), then combined these features in a multi-task learning framework for AD diagnosis. Gerardin et al. [15] modeled the shape of hippocampus regions via spherical harmonics and developed a classification procedure to automatically discriminate between patients. Sørensen et al. [16] employed various measurements to obtain expressive MRI biomarkers and fed them into a linear discriminant analysis system. However, these traditional computer-aided methods learn hand-crafted representations that can be prone to subjectivity and are difficult to optimize.

2.2. Deep learning methods for AD diagnosis

Recently, DL techniques have made great progress on AD diagnosis with the benefit of automatic abstraction of multi-level latent features. Chen and Xia [17] jointly used iterative sparse and DL methods to learn representations of critical cortical regions that are used to diagnose AD. Su et al. [18] introduced domain adaptation to utilize feature distributions of brain images across multiple sites for binary classification. Pan et al. [19] proposed a joint deep learning framework to model the disease-image specificity as well as the disease diagnosis using incomplete MRI and fluorodeoxyglucose positron emission tomography (PET) images. Basaia et al. [4] built a 3D CNN for MRI data to distinguish among AD, c-MCI and s-MCI without any prior feature design. Lei et al. [20] introduced a convolutional network based on longitudinal multiple-time-point data for identifying AD subjects. Lian et al. [21] proposed a hierarchical fully convolutional network that automatically learns multi-scale feature representations in whole-brain structural magnetic resonance imaging (sMRI) data for AD diagnosis. Kröll et al. [22] employed various residual structures to facilitate training and obtain information from previous layers. Gopinath et al. [23] proposed a new graph convolutional network for processing surface-valued data to output subject-based classification and regression. Despite the promising results of these models, almost all prior approaches are designed with complex modules that are difficult to interpret.


2.3. Explainability

A number of papers have proposed to visualize a model's predictions by highlighting important regions that are believed to be intuitive to end-users. If we consider an image classification task as an example, a "good" visual explanation based on the model should be (a) class-discriminative (i.e., localize the category in the image) and (b) high-resolution. Zhou et al. [24] introduced a technique called Class Activation Mapping (CAM) for identifying informative areas with a certain kind of classification CNN that does not have fully-connected layers. Substantially, it utilized the last convolutional layer before the global pooling layer and combined weighted activation maps to produce explainable heatmaps; it turned out to be highly class-discriminative, but with quite blurry outputs as an undesirable attribute. Beyond that, Grad-CAM [9] generalized CAM to a relatively large set of CNN models without requiring a specific architecture, by backpropagating the gradient of a target class with respect to the pixel intensities. Jalwana et al. [8] proposed a mechanism to generate high-resolution heatmaps with improved activation map upsampling that corresponds to a model's logic. However, gradients of a deep learning model can be noisy and can also easily vanish in a sigmoid function or an activation function like ReLU. Hence, Wang et al. [6] acquired the weight of each individual activation map by feeding it into the network, with the heatmaps yielded by the association between the corresponding weights and maps. Although these algorithms achieved remarkable improvements, they either did not combine the advantages of both sides (class-discriminative and high-resolution), or got stuck on one of them.

In the domain of medical image analysis, Hannun et al. [25] utilized the electrocardiogram tool to interpret the clinical ECG process in an end-to-end manner. Afshar et al. [26] took advantage of capsule networks to model nodule features and provide potential interpretability of the model. Malhotra et al. [27] proposed a multi-task model to predict COVID-19 in chest X-ray images and segmented the lung regions with COVID-19 symptoms. Xie et al. [28] conducted three iterations of design activities to formulate a system which enables clinicians to explore and understand AI-based chest X-ray analysis. Chittajallu et al. [29] presented a human-in-the-loop XAI system for content-based image retrieval of video frames similar to a query image from invasive surgery videos for surgical education. Jin et al. [30] proposed an attention-guided network to localize image biomarkers and provide intuitive explanations. Hu et al. [31] developed an interpretable multimodal fusion model by utilizing Grad-CAM. Nevertheless, existing DL methods do not provide high-resolution heatmaps and thus cannot give reliable explanations. With our proposed method HAM, we aim to produce visual explanations that carry fine-scale information as well as being class-discriminative.

3. Proposed methodology

In this section, we formulate the problem under consideration and introduce our proposed MAXNet, HAM and PCR.

3.1. Problem formulation

Firstly, given labeled training data {m_i, y_i}_{i=1}^M containing M samples, wherein y_i ∈ {0, 1} is a binary class label referring to the presence/absence of AD and m_i ∈ R^3 is an MRI volume, the proposed MAXNet aims to predict the corresponding diagnosis label ŷ_i given input m_i.

On the other hand, the fine-grained visualization task is to provide a heatmap A_HAM by integrating the activation maps of intermediate layers as well as the last convolutional layer:

A_{\mathrm{HAM}} = \sum_n U\Big(\mathrm{ReLU}\Big(N\Big(\tfrac{1}{Z_n}\sum \tfrac{\partial s_i}{\partial F^n}\Big)F^n\Big)\Big),    (1)

where Z_n is the number of filters in the nth layer, s_i is the predicted score, and F^n is the nth activation map. N(·) and U(·) represent the normalization and up-sampling functions, respectively.

Moreover, the task of evidence presentation is to firstly create a reference set {R_i^ref} for each label y_i from the training dataset {m, y}, where R_i^ref ⊆ m. Afterwards, we can retrieve the samples {R_c, y_c}_{c=1}^3 that have the most similar latent features to the input volume during the test phase. Table 1 lists the symbols used in our work and their descriptions.

Table 1
List of symbols and their descriptions.

m_i        The ith MRI volume in the training dataset
y_i        Label of m_i
ŷ_i        Predicted label of m_i
R_i^ref    Reference set for label y_i
R_c        Reference sample, where c = 1, 2, 3
m_k^T      The kth MRI volume in the testing dataset
ŷ_k^T      Predicted label of m_k^T
p_k        Latent features of the kth MRI volume m_k^T
p_c        Latent features of the cth reference sample R_c

3.2. Framework overview

We propose an explainable framework for automated diagnosis of AD from MRI volumes, which is capable of providing accurate classification results with fine-grained visualization maps and a prediction basis. The schematic of our framework, including MAXNet, HAM, and PCR, is shown in Fig. 2. We first craft the so-called Multi-scale Attention eXplainable Network (MAXNet), which can be trained in an end-to-end fashion, so as to address the aforementioned challenging issues and power the visual interpretability elaborately. Then we present a new high-resolution visualization approach, referred to as High-resolution Activation Mapping (HAM), which extracts salient features related to AD (e.g., the atrophy of the cerebral cortex and hippocampus) to interpret model decisions. Furthermore, a reference set R_ref is created by the Prediction-basis Creation and Retrieval module during training to extract and save relevant samples for certain labels, and is then used during testing to provide evidence samples R_c with labels y_c.

4. MAXNet architecture

Several essential modules constitute MAXNet: 1) the staged feature extraction flow; 2) the Dual Attention Module (DAM); 3) the Multi-resolution Fusion Module (MFM); and 4) the cluster and contrastive loss functions. In comparison to mainstream CNN models, MAXNet has multi-resolution fields that are highly complementary for capturing accurate localization of homogeneous areas. It is also able to learn latent representations that are close to each volume's label with the proposed cluster and contrastive loss functions.


Fig. 2. Schematic of the overall framework, which consists of the explainable model MAXNet and the explainable tool, i.e., HAM and the PCR module. In MAXNet, the high-to-low convolutional stream forms several stages (stages 1-5). We define F^n (n ∈ [1, 2, 3, 4, 5]) as the intermediate activation response of the nth stage before the max-pooling layer, and G^n as the final output of each stage n after max-pooling. F^3 and F^4 are leveraged to form the voxel-wise feature maps P^3, P^4 and the depth-wise feature maps D^3, D^4 via the DAM (see Fig. 3), respectively. Note that G^5 from the last convolutional layer only extracts global features of the pathological abnormalities and misses the small subjects and discrepancies. Eventually, P^3, D^3, P^4, D^4, and G^5 are fused via the MFM (see Fig. 4) to produce the classification label ŷ_i. Subsequently, visual explanations A_HAM are obtained via HAM by multi-stage aggregation, and PCR is used to retrieve the three reference samples R_1, R_2 and R_3 most similar to the input volume, which are displayed as evidence with ground-truth labels y_1, y_2, y_3.

Fig. 3. Block diagram of the Dual Attention Module (DAM), which is embedded into several stages of MAXNet, with the objective of simultaneously capturing both voxel-wise and depth-wise dependencies and variations of the feature maps P^n and D^n in hidden layers.

4.1. Staged feature extraction

We start our network with a high-to-low convolutional stream, devised as five stages in MAXNet as shown in Fig. 2. Each stage is a convolutional block, sequentially made of a convolutional layer, a batch normalization (BN) layer, a rectified linear unit (ReLU), and a max-pooling layer. We make several adaptations to create our high-to-low stream. First, in stages 1 and 2, which produce larger spatial outputs compared to their higher counterparts, the kernel size is set to 3 × 3 × 3 and the number of filters is set to 15 and 25, respectively, to save computational resources. Second, since the DAM is applied to stages 3 and 4, we increase the expressive capacity across these two blocks by setting the number of convolutional kernels to be the same as in stage 5, i.e., 50. We use convolutional layers with a kernel size of 3 × 3 × 3 for stages 3 and 4, and 1 × 1 × 1 for stage 5.
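To make the staged extraction concrete, below is a minimal PyTorch sketch of one such stage and the five-stage stream. It is an illustration under the filter counts and kernel sizes stated above; padding, pooling stride, and the toy input size are our assumptions rather than details given in the paper.

```python
import torch
import torch.nn as nn

class Stage(nn.Module):
    """One MAXNet-style stage: Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d."""
    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size,
                              padding=kernel_size // 2)  # assumed padding
        self.bn = nn.BatchNorm3d(out_ch)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        f = torch.relu(self.bn(self.conv(x)))  # F^n: response before pooling
        g = self.pool(f)                       # G^n: stage output
        return f, g

# Stages 1-2: 15 and 25 filters; stages 3-5: 50 filters; 1x1x1 kernels in stage 5.
stages = nn.ModuleList([
    Stage(1, 15), Stage(15, 25), Stage(25, 50),
    Stage(50, 50), Stage(50, 50, kernel_size=1),
])

g = torch.randn(2, 1, 64, 64, 64)  # toy batch; real volumes are 169x208x179
features = []
for stage in stages:
    f, g = stage(g)
    features.append(f)  # F^3 and F^4 feed the DAM; the final g is G^5
```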


Fig. 4. Block diagram of the Multi-resolution Fusion Module (MFM), which aggregates the multi-resolution features P^n, D^n, and G^5 by use of several fully-connected layers.

4.2. Dual attention module (DAM)

A classical classification model, which usually extracts features by treating each sub-area equivalently, can exclude much information about local context clues in its upper layers. Our intention is thus to devise a Dual Attention Module (DAM), as demonstrated in Fig. 3, to capture both voxel-wise and depth-wise dependencies from both high- and low-resolution feature maps. Consequently, the multi-resolution relationships can be well represented for the model decisions.

4.2.1. Voxel-wise attention

For voxel-wise dependencies and differences between different stages, a voxel-wise attention module is applied to both the current stage and the final stage, as depicted in Fig. 3(a). Specifically, the encoding process for stage n (n ∈ [3, 4]) involves three steps. Firstly, we map F^n and G^5 onto a mutual embedding space:

\hat{F}^n = W_f(F^n), \quad G_A = W_a(G^5),    (2)

where W_f(·) contains one convolutional layer Conv(filters = 25, kernel size = 1, stride = 1), and W_a(·) is composed of one learnable convolutional layer Conv(filters = 25, kernel size = 1, stride = 1) followed by one up-sampling layer. After projection we obtain \hat{F}^n ∈ R^{D'_n × C'_n × H'_n × W'_n} and G_A ∈ R^{D'_n × C'_n × H'_n × W'_n}. Secondly, we perform an element-wise product between G_A and \hat{F}^n to obtain the following interaction-aware attention matrix:

c_{i,j} = \hat{F}^n(i) \times G_A(j),    (3)

where c_{i,j} represents the correlation of voxels <i, j> for all elements in the activation feature maps. Note that G^5 has coarser but semantically stronger feature responses; thus, \hat{F}^n and G_A have the same resolution but different temporal contextual coverage. Subsequently, we normalize c_{i,j} by

r_{i,j} = \frac{c_{i,j} - \min(c_{i,j})}{\sum_{i,j}\,[c_{i,j} - \min(c_{i,j})]}.    (4)

The above normalization operation bears some resemblance to the soft-max function but does not generate a sparse output. Finally, we define a more discriminative representation P^n by

P^n = \sum_{i=1}^{N} r_{i,j} \times F_i^n.    (5)

Through the use of this attention module, the proposed network is able to tell exactly where to look at the slice level, and further retrieve the visual explanation at a finer scale.

4.2.2. Depth-wise attention

As depicted in Fig. 3(b), we propose a depth-wise attention module to perceive the 3D context between slices. With \hat{F}^n and G_A acquired by Eq. (2), we apply a transpose operation to obtain \hat{F}^{nT} ∈ R^{H'_n × W'_n × D'_n × C'_n} and G_A^T ∈ R^{H'_n × W'_n × D'_n × C'_n}. Then the inner product is taken as

c_{i,j}^D = \hat{F}^{nT}(i)\,G_A^T(j),    (6)

and the obtained c_{i,j}^D is normalized by Eq. (4) to get r_{i,j}^D. Subsequently, a depth-wise feature map D^n is computed by

D^n = \sum_{i=1}^{N} r_i^D F_i^n.    (7)

By predicting the result based on all P^n and D^n, whose finer and diverse receptive fields provide views at both the voxel and channel levels, the network is enhanced to concentrate on the most considerable partial regions, boost the influence of subtle distinctions, and inhibit the background or trivial noise. Briefly speaking, the advantages of this proposed mechanism can be proclaimed on three fronts: 1) employing the voxel-wise attention allows low-scale stages to pay more attention to learning both local and global context attributes; 2) with the elaborate design of a depth-wise attention block, the model is extended to learn complex and flexible correlations between 3D features; and 3) the DAM is significant since medical imaging data are intrinsically noisy, in which case a trainable block, rather than a linear parameter, may more easily reach the global optimum.

4.3. Multi-resolution fusion module (MFM)

In order to encourage the diversity of learned feature activations and enforce these features to be close to the label of their input, we construct the Multi-resolution Fusion Module (MFM) to combine multi-resolution features. The structure of the MFM is illustrated in Fig. 4.

We argue that a fusion module is supposed to be adaptive and able to be fine-tuned in accordance with specific application scenarios. Firstly, we combine P^n and D^n as follows:

\hat{F}^n = \beta_1 P^n + \beta_2 D^n, \quad n \in [3, 4],    (8)

where \beta_1 and \beta_2 are initially set to 0.5 and are learnable via the back-propagation algorithm.

Secondly, we define a set of important class-discriminative latent features p for the input m_i as follows:

p(m_i) = \max\Big( \mathrm{ReLU}\Big(N\Big(\tfrac{1}{Z_n} \sum_m \sum_{p \in R_{i,j,k}} \hat{F}^n(i,j,k)\Big)\Big),\; U\Big(\mathrm{ReLU}\Big(N\Big(\tfrac{1}{Z_5} \sum_m \sum_{p \in R_{i,j,k}} G^5(i,j,k)\Big)\Big)\Big) \Big), \quad n \in [3, 4],    (9)

where Z_n / Z_5 is the number of convolution filters in \hat{F}^n / G^5. By using Eq. (9), p offers more accurate localization of important features by considering maximum values both from the intermediate features and the last convolutional features. We then project these latent features as

z_j = -\log(\|\hat{z}_j - p_j\|_2^2) + \eta, \quad j \in \{1, \ldots, N\},    (10)

where \hat{z}_j is extracted from \hat{F}^3, \hat{F}^4 and G^5 after FC layers, and \eta is empirically set to 1e-4. The final output is produced by

\hat{y}_i = \mathrm{argmax}(\mathrm{softmax}(FC(z))).    (11)

The intuition behind this is that the predicted score FC(z) w.r.t. \hat{y}_i is high when the latent features p preserved by z are important.
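As an illustration of how the voxel-wise weighting of Eqs. (2)-(5) and the fusion of Eq. (8) might be realized, here is a hedged PyTorch sketch. The projection widths, the reduction axes, and the choice to weight the embedded maps (the paper's Eq. (5) weights F^n itself) are our assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

w_f = nn.Conv3d(50, 25, kernel_size=1)  # W_f of Eq. (2)
w_a = nn.Conv3d(50, 25, kernel_size=1)  # W_a of Eq. (2); up-sampling applied below

def voxel_attention(f_n, g5):
    # Eq. (2): mutual embedding; G^5 is up-sampled to F^n's resolution.
    f_hat = w_f(f_n)
    g_a = F.interpolate(w_a(g5), size=f_n.shape[2:], mode="trilinear",
                        align_corners=False)
    c = f_hat * g_a                              # Eq. (3): element-wise product
    c = c - c.amin(dim=(2, 3, 4), keepdim=True)  # Eq. (4): shift by the minimum
    r = c / c.sum(dim=(2, 3, 4), keepdim=True).clamp_min(1e-8)
    return r * f_hat                             # Eq. (5), applied to embeddings

class MFMFuse(nn.Module):
    """Eq. (8): learnable combination; both betas start at 0.5."""
    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor([0.5, 0.5]))
    def forward(self, p_n, d_n):
        return self.beta[0] * p_n + self.beta[1] * d_n

f3 = torch.randn(2, 50, 16, 16, 16)  # toy F^3
g5 = torch.randn(2, 50, 4, 4, 4)     # toy G^5
p3 = voxel_attention(f3, g5)         # voxel-wise attended features
fused = MFMFuse()(p3, p3)            # stand-in for combining P^n with D^n
```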
yˆi is high when latent features p preserved by z are important.


In that case, the model is able to learn good representations merely by asking the latent features p to be close to their predicted label.

4.4. Loss function

Although both the DAM and MFM provide a strong capacity for feature learning, it is non-trivial to obtain interpretable representations without additional regularization. Therefore, we propose to learn a meaningful latent space via additional objective constraints, i.e., a cluster loss L_cls and a contrastive loss L_ctr, with which the most important features are clustered around the ground-truth label and are well separated from features related to other labels. We achieve this goal by jointly optimizing the following loss function:

L = L_{ce} + \alpha_1 L_{cls} + \alpha_2 L_{ctr},    (12)

L_{cls} = \frac{1}{n} \sum_{i=1}^{n} \min_{\hat{z}_i} \|\hat{z}_i - p_i\|_2^2, \quad p_i \in p_{y_i},    (13)

L_{ctr} = -\frac{1}{m} \sum_{i=1}^{m} \min_{\hat{z}_i} \|\hat{z}_i - p_i\|_2^2, \quad p_i \notin p_{y_i},    (14)

p_{y_i} = \max\Big( \mathrm{ReLU}\Big(N\Big(\tfrac{1}{Z_n} \sum_m \sum_{p \in R_{i,j,k}} s_n^{y_i} \hat{F}^n(i,j,k)\Big)\Big),\; U\Big(\mathrm{ReLU}\Big(N\Big(\tfrac{1}{Z_5} \sum_m \sum_{p \in R_{i,j,k}} s_5^{y_i} G^5(i,j,k)\Big)\Big)\Big) \Big), \quad n \in [3, 4],    (15)

where s_n^{y_i} is the predicted score, L_ce is the cross-entropy loss, and \alpha_1, \alpha_2 are hyper-parameters. Intuitively, minimizing L_cls encourages the model to have at least one representation similar to its true label's latent features, while the contrastive loss L_ctr penalizes the similarity between its representations and other labels' features.

Consequently, with the L_cls and L_ctr terms, the loss function in Eq. (12) encourages our model to learn and cluster the latent features into a semantically meaningful space, which facilitates the prediction of MAXNet and the generation of fine-grained interpretable heatmaps.
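A minimal sketch of how the cluster and contrastive terms of Eqs. (13) and (14) could be computed is given below, assuming the latent vectors and label prototypes are stacked into matrices; the shapes and the per-volume formulation are our assumptions.

```python
import torch

def cluster_loss(z_hat, protos_same):
    """Eq. (13): pull latent vectors toward the prototypes of the volume's
    own label. z_hat: (Nz, d) latent vectors; protos_same: (Np, d)."""
    d = torch.cdist(protos_same, z_hat) ** 2  # squared L2 distances, (Np, Nz)
    return d.min(dim=1).values.mean()         # closest z for each prototype

def contrastive_loss(z_hat, protos_other):
    """Eq. (14): push latent vectors away from other labels' prototypes."""
    d = torch.cdist(protos_other, z_hat) ** 2
    return -d.min(dim=1).values.mean()

# Eq. (12), with alpha_1 = 0.6 and alpha_2 = 0.06 as reported in Section 6.2:
# loss = ce + 0.6 * cluster_loss(z, p_same) + 0.06 * contrastive_loss(z, p_other)
```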
5. Explaining the MAXNet predictions

In this section, we propose HAM to capture fine-grained heatmaps A_HAM from a different perspective. This section also elaborates on the PCR module, which provides supplemental evidence in the form of reference samples with ground-truth labels.

5.1. Proposed HAM for high-resolution heatmaps

Most existing visualization methods only consider the last convolutional layer, which extracts global features of the pathological abnormalities and misses the small subjects and discrepancies. Instead, we propose High-resolution Activation Mapping (HAM), which considers values from both the intermediate features and the last convolutional features, so as to offer more accurate localization. The detailed structure is depicted in Fig. 5. Recall that ŷ_i is the model's predicted label for the input m_i, and s_i is its corresponding predicted score before the soft-max function. Based on the proposed DAM operated on F^n in stages 3 and 4, it follows that

\frac{\partial s_i}{\partial F^n} = W1^n + W2^n,    (16)

where ∂s_i/∂F^n is the gradient of s_i w.r.t. F^n, decomposed into two terms: W1^n, derived from the stages, and W2^n, the gradient flowing back from the DAM. Then the visualization map A_HAM is given by

A_{\mathrm{HAM}} = \max\Big( U\Big(\mathrm{ReLU}\Big(N\Big(\tfrac{1}{m} \sum_m \tfrac{\partial s_i}{\partial F^3} F^3(i,j,k)\Big)\Big)\Big),\; U\Big(\mathrm{ReLU}\Big(N\Big(\tfrac{1}{m'} \sum_{m'} \tfrac{\partial s_i}{\partial F^n} F^n(i,j,k)\Big)\Big)\Big) \Big), \quad n \in [4, 5],    (17)

where F^5 = G^5, and m and m' are the numbers of convolution filters for F^3 and F^n, respectively.

Overall, this method takes the intermediate activations as well as the features from the last convolutional layer as input, which is distinctly different from state-of-the-art methods that only employ the last convolutional features. Therefore, compared to other techniques that produce blurry maps and lose too much discriminative information, our HAM approach is able to learn and identify high-resolution features of brain areas by successfully capturing diverse cues.

5.2. Prediction-basis creation and retrieval (PCR)

Figure 2 illustrates the proposed PCR module for our explainable model MAXNet. Our intuition is that we want to identify samples that have morphologically similar features compared to the input volume. First, let us define the so-called reference sample and formulate the problem that the PCR aims to tackle.

Definition 1 (Reference Sample). Given an MRI volume m_k^T, k ∈ [1, 2, ..., K], in the test dataset and the proposed model MAXNet ψ(·), R_c is called a reference sample of m_k^T when

R_c = \mathrm{argmin}_{R_c}\, D(R_c, m_k^T) \;\; \mathrm{s.t.} \;\; \psi(m_k^T) = \psi(R_c),    (18)

where D(·) is a function evaluating the similarity between R_c and m_k^T.

Problem 1. Given m_k^T and our model ψ(·), let ŷ_k^T = ψ(m_k^T) and let p_k be the latent features of m_k^T. The goal is to retrieve reference samples {R_c, y_c}_{c=1}^3 and the corresponding latent features p_c similar to p_k, with the objective of providing instance-level justifications for the model output ŷ_k^T.

We found that the well-trained MAXNet is able to learn pivotal and varied features in brain images, e.g., the atrophy of the cerebral cortex and hippocampus, the enlargement of the frontal and temporal horns of the lateral ventricles, and the enlarged sulcal spaces with atrophy of gyri. These pathological changes are believed to be important for AD diagnosis by expert clinicians [2]. As a result, we consider the similarity of two generated latent representations, since they consist of a set of representative features for predictions. That is, the retrieved p_c is minimally different from p_k.

Firstly, for each training sample {m_i, y_i}, MAXNet predicts its label ŷ_i and calculates its latent representation p_{ŷ_i} by Eq. (15). In what follows, we construct an auxiliary diagnosis reference set based on the training dataset, which contains both volumes R_ref as potential reference samples and the corresponding latent features p. Specifically, R_ref consists of subsets R_i^ref, and p contains a subset p_{y_i} for each y_i, i ∈ [0, 1]. Algorithm 1 details the steps to generate R_ref and p.

Algorithm 1 Generate the reference set R_ref and latent features p.
Input: ψ(·), {m_i, y_i}_{i=1}^M, {y_i}_{i=0}^1
Output: R_i^ref, p_{y_i}
1: Initialize: R_i^ref ← {}, p_{y_i} ← {}
2: for m_i ∈ m do
3:   ŷ_i ← ψ(m_i)
4:   p_{ŷ_i} ← Eq. (15)
5:   if ŷ_i == y_i then
6:     R_i^ref ← R_i^ref ∪ m_i
7:     p_{y_i} ← p_{y_i} ∪ p_{ŷ_i}
8:   end if
9: end for
10: return R_i^ref, p_{y_i}

Secondly, given an input m_k^T during testing, MAXNet yields the prediction ŷ_k^T and PCR retrieves reference samples R_c with labels y_c by Eq. (18), where we define the similarity evaluation function D(·) as follows:

D(R_c, m_k^T) = \frac{1}{n} \sum_{j=1}^{n} \|p_j^c - p_j^k\|_2^2, \quad p_j^k \in p_{y_k^T},\; p_j^c \in p_{y_c},    (19)

where ||·||_2 is the L2 norm, and p_{y_k} and p_{y_c} are calculated by Eq. (15). This matches our intuition that a reference sample should contain discriminative features that are highly aligned with those of the input sample.

With Eqs. (18) and (19), PCR is able to retrieve samples R_c whose latent representations p_c are morphologically similar to m_k^T's features p_k. The p_c can be particularly beneficial for internists, since the most important features for classification are retained. The PCR module not only helps gain the trust and acknowledgement of human users in evidence-centered fields such as medical imaging, but also provides scientific confidence in real-world applications. Moreover, we believe this module can be extended to benefit other interpretable processes in multi-class classification problems.
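The reference-set construction of Algorithm 1 and the retrieval by Eqs. (18) and (19) can be sketched as follows. This is an illustrative Python sketch, not the authors' code: `predict` and `latent` are hypothetical helpers standing in for MAXNet's forward pass and latent-feature extraction, and the latent features are assumed to be stacked into an (n, d) tensor.

```python
import torch

def build_reference_set(predict, latent, volumes, labels):
    """Algorithm 1: keep only correctly classified training volumes,
    together with their latent features p_{y_i}."""
    ref = {0: [], 1: []}
    for m, y in zip(volumes, labels):
        if predict(m) == y:
            ref[y].append((m, latent(m)))
    return ref

def retrieve(predict, latent, ref, m_test, top_c=3):
    """Eq. (18): search only references sharing the predicted label;
    Eq. (19): rank them by mean squared L2 distance of latent features."""
    y_hat = predict(m_test)
    p_k = latent(m_test)                  # shape (n, d), assumed
    scored = [(((p_c - p_k) ** 2).sum(dim=1).mean().item(), m)
              for m, p_c in ref[y_hat]]
    scored.sort(key=lambda t: t[0])
    return [m for _, m in scored[:top_c]]  # R_1, R_2, R_3
```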


Fig. 5. Block diagram of the High-resolution Activation Mapping (HAM). Each arrow shows the gradient of the classification logit. Our method takes intermediate activations as inputs and considers the maximum values from the intermediate features F^3 and F^4 as well as the final activation G^5, which offers more accurate localization.
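Complementing Fig. 5, a minimal PyTorch sketch of how the HAM map of Eq. (17) might be computed is shown below; the normalization N(·) is assumed to be min-max scaling, which the paper does not specify.

```python
import torch
import torch.nn.functional as F

def ham_map(score, feats, out_size):
    """score: scalar logit s_i; feats: [F^3, F^4, G^5], each requiring grad;
    returns a heatmap of spatial size out_size."""
    maps = []
    for f in feats:
        grad = torch.autograd.grad(score, f, retain_graph=True)[0]  # ds_i/dF^n
        a = F.relu((grad * f).mean(dim=1, keepdim=True))  # average over filters
        a = (a - a.min()) / (a.max() - a.min() + 1e-8)    # N(.), assumed min-max
        a = F.interpolate(a, size=out_size, mode="trilinear",
                          align_corners=False)            # U(.)
        maps.append(a)
    return torch.stack(maps).amax(dim=0)  # voxel-wise max across layers
```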

5.3. Reasoning process of our PCR module

In this section, we present a probabilistic explanation of the proposed PCR's reasoning process.

Firstly, we consider the classification task as a conditional probability estimation problem, in which our goal is to obtain the conditional distribution P(Y = y_k, X = m_k^T). Inspired by Bayes' theorem, the problem can be formulated as

P(Y = y_k \mid X = m_k^T) = \frac{P(X = m_k^T \mid Y = y_k)\, P(Y = y_k)}{\sum_i P(X = m_k^T \mid Y = y_i)\, P(Y = y_i)}.    (20)

Then we define a group of latent features z_i^k(m_k^T) = \mathrm{argmin}_{z_i} \|z_i - p_i\|_2^2, p_i ∈ p_{y_k}, where z_i is sampled from \hat{z}_i and p_{y_k} is computed by Eq. (15). Here we assume that for any input m_k^T there exists only one latent representation z_i that is most similar to p_i. As a result, we can make the reasonable assumption that z^k(m_k^T) contains sufficient information about y_k. We then prove that the label-conditional probability P(X = m_k^T | Y = y_k) can be written as

P(X = m_k^T \mid Y = y_k) = P(X = m_k^T \mid z_1^k(m_k^T) = z_1, \ldots, z_n^k(m_k^T) = z_n, Y = y_k) \cdot P(z_1^k(m_k^T) = z_1, \ldots, z_n^k(m_k^T) = z_n \mid Y = y_k).    (21)

Based on Eq. (21), it can be concluded that if X = m_k^T, then the probability of z_1^k(m_k^T) = z_1, ..., z_n^k(m_k^T) = z_n should be 1. Subsequently, we make another assumption:

P(X = m \mid z_1^k(m) = z_1, \ldots, z_n^k(m) = z_n, Y = y_k) = P(X = m \mid z_1^j(m) = z_1, \ldots, z_n^j(m) = z_n, Y = y_j), \quad \forall m \in m^T,\; \forall y_k, y_j \in \{0, 1\},    (22)

which means that, for a given label y_k or y_j, the probability that m's latent features z(m) are most similar to y_k or y_j is essentially the same. Plugging Eqs. (21) and (22) into Eq. (20) gives rise to

P(Y = y_k \mid X = m_k^T) = \frac{P(z_1^k(m) = z_1, \ldots, z_n^k(m) = z_n \mid Y = y_k)\, P(Y = y_k)}{\sum_j P(z_1^j(m) = z_1, \ldots, z_n^j(m) = z_n \mid Y = y_j)\, P(Y = y_j)},    (23)

where P(z_1^k(m) = z_1, ..., z_n^k(m) = z_n | Y = y_k) = \mu(\|z - p_{y_k}\|_2^2) is the optimal distribution based on our loss function in Eq. (13).

Based on the above equations, it can be concluded that a reference sample R_c with latent representations p_c theoretically guarantees accurate instance-level explanations provided by R_c if it satisfies R_c = \mathrm{argmin}_{R_c} \frac{1}{n} \sum_{j=1}^{n} \|p_j^c - p_j^k\|_2^2.

We further analyze the impact of a reference sample on the original prediction accuracy.

Theorem 1. Given an MRI volume m_k^T and model ψ(·), let R_c be a reference sample of m_k^T with label y_c, let ŷ_k^T = ψ(m_k^T), let the latent representations for ŷ_k^T and y_c be p_k and p_c, respectively, and let z_j^{y_k} be extracted using Eq. (10). Assume that:

• there exists 0 < ξ < 1 such that \|p_j^c - p_j^k\|_2 \le (\sqrt{1 + \xi} - 1)\,\|z_j^{y_k} - p_j^k\|_2 and \|z_j^{y_k} - p_j^k\|_2 \le 1 - \xi for j ∈ {1, ..., n};
• the weight in the last FC layer is 1 for z_j^{y_k} = -\mathrm{argmax}_{\hat{z}_j} \log(\|\hat{z}_j - p_j^{y_i}\|_2^2) + \eta, and 0 otherwise.

Then using p_c in lieu of p_k can modify ψ(·)'s predicted logit by at most \Delta_k = \sum_j \log(1 + \xi), j ∈ {1, ..., N}. If the logit scores of the correct and incorrect labels differ by at least 2\Delta_k, then the latent features p_c of the reference sample R_c can be used to correctly explain ψ(·)'s decision about m_k^T.


Proof. Denote by s_k ψ(·)'s logit score for the correctly predicted label ŷ_k^T. Then it follows from Eq. (10) that

s_k = -\sum_j \log(\|z_j^{y_k} - p_j^k\|_2^2) + \eta, \quad j \in \{1, \ldots, N\}.    (24)

Let \Delta_k be the logit change caused by choosing reference sample R_c to explain ψ(·)'s decision about m_k^T. Then we have

\Delta_k = s_k - s_c = \sum_j \log\left(\frac{\|z_j^{y_k} - p_j^c\|_2^2}{\|z_j^{y_k} - p_j^k\|_2^2}\right).    (25)

According to the assumptions in Theorem 1, we have \|p_j^c - p_j^k\|_2 \le (\sqrt{1 + \xi} - 1)\,\|z_j^{y_k} - p_j^k\|_2 and \|z_j^{y_k} - p_j^c\|_2 \le \|z_j^{y_k} - p_j^k\|_2 + \|p_j^c - p_j^k\|_2, which in turn gives us

\Delta_k = \sum_j \log\left(\frac{\|z_j^{y_k} - p_j^c\|_2^2}{\|z_j^{y_k} - p_j^k\|_2^2}\right) \le \sum_j \log(1 + \xi).    (26)

Subsequently, we suppose that the correct logit score s_k is at least 2\Delta_k larger than any incorrect score s_i, i.e., s_k \ge s_i + 2\Delta_k. Therefore, when using the reference sample R_c's latent features p_c to explain ψ(·)'s decision about m_k^T, we have

s_c \ge s_k - \Delta_k \ge s_i.    (27)

Given Eq. (27), we can claim that model ψ(·) can still correctly classify the volume with the latent representations p_c provided by the reference sample R_c. □

It is noted in the experiments that, in our well-trained MAXNet, the assumption s_k \ge s_i + 2\Delta_k always holds. Moreover, the distance \|p_j^k - p_j^c\|_2 is generally smaller than \|z_j^{y_k} - p_j^k\|_2, which verifies our assumptions and in turn confirms the effectiveness of our PCR module. Empirically, the value of ξ is set to 0.24.

5.4. Metrics for evaluation of PCR

In order to evaluate the accuracy of the reference set R_c, we design two evaluation metrics and conduct a series of experiments to quantify the effectiveness of reference samples.

Definition 2 (Swap Deletion Confidence).

\rho_{m_k, r_c}^D = \frac{(s(m_k^T) - s(m_k^T \odot K)) \odot (s(m_k^T) - s(m_k^T \odot C))}{\|s(m_k^T) - s(m_k^T \odot K)\|_2 \cdot \|s(m_k^T) - s(m_k^T \odot C)\|_2},    (28)

where s(·) is the predicted score and ⊙ is the Hadamard product. K and C have the same dimension as m_k^T. For each k_i ∈ K, k_i = 0 if position i is located in p_k, otherwise k_i = 1; for c_j ∈ C, c_j = 0 if position j is located in p_c, otherwise c_j = 1. Consequently, \rho_{m_k, r_c}^D measures the similarity between the two prediction changes.

As detailed in Theorem 1, p_c and p_k have been proved to be expressive features for m_k^T's prediction. Therefore, our intuition here is to evaluate whether the predictions change similarly when the features p_c / p_k are removed from m_k^T. Arguably, a larger \rho_{m_k, r_c}^D indicates that p_c causes a decrease in prediction accuracy similar to p_k's when removed.

Following Definition 2, we define the so-called Swap Insertion Confidence from a different yet complementary angle.

Definition 3 (Swap Insertion Confidence).

\rho_{m_k, r_c}^I = \frac{s(m_k^T \odot K') \odot s(m_k^T \odot C')}{\|s(m_k^T \odot K')\|_2 \cdot \|s(m_k^T \odot C')\|_2},    (29)

where for k'_i ∈ K', k'_i = 1 if position i is located in p_k, otherwise k'_i = 0; for c'_j ∈ C', c'_j = 1 if position j is located in p_c, otherwise c'_j = 0. Likewise, the intuition behind this is to evaluate the similarity of the prediction changes when adding the features of p_c / p_k into m_k^T. A larger \rho_{m_k, r_c}^I means that p_c causes an increase in prediction accuracy similar to p_k's when added. In the following sections, we provide experimental results concerning these two evaluation metrics.

6. Experimental results and discussions

6.1. Dataset

The dataset used in our experimental studies is the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset [32,33]. We select the T1-weighted, pre-processed, baseline MRI data in the ADNI dataset, with a single scan selected per subject visit. To address the issue of explainability and keep the classification task simple, we only select two diagnosis groups in ADNI, which contain 826 cognitively normal individuals and 422 Alzheimer's patients with at least one session's MRI volumes available. With consideration of data heterogeneity, we carefully extract data samples from the ADNI dataset to form three non-overlapping subsets. Each subset is further split into 1779 images for training, 427 for validation, and 575 for testing. In order to avoid biased generalization estimates due to same-subject image similarities, each subject is selected into just one of the sets (i.e., the training, validation, and test sets) for each subset. Finally, each volume is cropped to a size of 169 × 208 × 179 for training, validation, and testing.

6.2. Implementation details

Our proposed MAXNet is implemented in PyTorch and executed on two Nvidia Volta V100 GPUs with 16 GB of memory each. It is trained using the Adam optimizer with a weight decay value of 0.0005, and the batch size is fixed to 8 samples. The initial learning rate is set to 0.0001 and is decayed according to a polynomial schedule. We pre-train the model with the cross-entropy loss function L_ce in Eq. (12) for the initial 20 epochs and fine-tune it with the cluster and contrastive losses for 50 epochs. The values of the hyper-parameters α_1 and α_2 are set to 0.6 and 0.06, respectively, after conducting extensive experiments.

6.3. Performance of MAXNet

We compare the classification performance of the proposed architecture with other interpretable models. Following [34], we resort to two XAI properties, i.e., continuity and selectivity (more details can be found in [34]), to qualify the interpretability of MAXNet.

The comparison results are presented in Table 2. It is noted that the existing interpretable models perform evaluation with different cohorts of subjects, and the indices of those subjects were not disclosed. To take data heterogeneity into account, we train and evaluate our model on the three non-overlapping subsets extracted from the ADNI dataset. Table 2 reports not only the classification results of the comparison models, but also the number of subjects used by each model. As can be observed from the table, AlexNet 3D [34] shows slightly better classification performance than VGGNet 3D [35]. Lee's model [7] yields better prediction outcomes than its VGG or AlexNet counterparts, as it derives complex nonlinear relationships for predefined regions. The experimental results presented in Table 2 demonstrate that our proposed MAXNet clearly outperforms its competitors, as it obtains the highest accuracy of 95.4% and the second-highest AUC of 98.0% on subset 3 of the ADNI dataset.
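For concreteness, the training recipe of Section 6.2 might look as follows in PyTorch. This is a hypothetical sketch: `model`, `loader`, and the two auxiliary loss helpers are placeholders for the actual implementation, and the polynomial decay exponent is our assumption since the paper only names the schedule.

```python
import torch

ce_loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=70,
                                                  power=0.9)  # assumed power

for epoch in range(70):             # 20 epochs pre-training + 50 fine-tuning
    for volumes, labels in loader:  # batch size 8
        logits, z, p = model(volumes)
        loss = ce_loss(logits, labels)
        if epoch >= 20:  # add the cluster/contrastive terms of Eq. (12)
            loss = loss + 0.6 * cluster_loss(z, p, labels) \
                        + 0.06 * contrastive_loss(z, p, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()
```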


Table 2
Comparative results of various interpretable models on ADNI.

Model               Subjects (AD / NC)   ACC      AUC      Continuity   Selectivity
Lee et al. [7]      198 / 229            0.9275   0.9804   -            -
3DAN [30]           227 / 305            0.861    0.912    -            -
Kröll et al. [22]   153 / 306            -        0.815    -            -
VGGNet 3D [35]      47 / 56              0.766    0.863    -            -
ResNet 3D [35]      47 / 56              0.854    0.794    -            -
AlexNet 2D+C [34]   422 / 826            -        0.923    30.361       -0.059
AlexNet 3D [34]     422 / 826            -        0.898    37.887       0.215
VGG16 2D+C [34]     422 / 826            -        0.892    24.928       0.224
VGG16 3D [34]       422 / 826            -        0.886    41.879       0.039
MAXNet (subset 1)   422 / 826            0.928    0.959    14.61        -0.79
MAXNet (subset 2)   422 / 826            0.953    0.978    15.22        -0.87
MAXNet (subset 3)   422 / 826            0.954    0.980    15.27        -0.71

Although Lee's model [7] achieves a slightly higher AUC score than our MAXNet, its accuracy is lower than ours and its results were validated on far less MRI data (427 vs. 1248 subjects). As a result, the experimental results reported in Table 2 suggest that our proposed model MAXNet is capable of offering accurate diagnoses for AD vs. NC classification. Last but not least, it should be noted that the heterogeneity among the data samples drawn from ADNI is not considered in [7,22,30,34,35], while our work proves to be robust across the samples in the ADNI dataset.

The results for the continuity and selectivity metrics are also shown in Table 2. Note that some models' results are not provided because they are not reported in the relevant papers. Lower values of continuity suggest that similar MRI scans have more similar heatmaps. The continuity value of our proposed model is notably lower than those of the other models: MAXNet achieves a continuity of 15.27 on subset 3 and a selectivity of -0.87 on subset 2, both consistently superior to the models considered in [34]. This means MAXNet is able to capture distinctive patterns of disease areas by use of the latent feature representations, and produces consistent results given similar MRI scans. Moreover, selectivity quantifies the change in the classification probability when the related features are removed gradually; lower selectivity values mean that similar MRI volumes have similar relevant features in their heatmaps [34]. As observed, the selectivity values obtained via AlexNet 2D+C, AlexNet 3D, VGG16 2D+C, and VGG16 3D [34] vary only slightly, which confirms a low correlation between the heatmaps and the predictions of these models. Instead, our model MAXNet achieves the lowest selectivity value, which indicates that the learned latent features are closely related to its final outputs and well demonstrates its distinctive interpretability.

Table 3
Comparative evaluation of HAM and other methods.

Visualization method   Insertion   Deletion   Runtime
Grad-CAM [9]           0.492       0.822      0.027
Grad-CAM++ [10]        0.554       0.743      0.030
RISE [36]              0.603       0.576      29.32
Score-CAM [6]          0.761       0.362      19.06
CAMERAS [8]            0.676       0.523      3.25
HAM                    0.801       0.263      0.092

6.4. Qualitative evaluation of HAM and PCR

6.4.1. Faithfulness and complexity evaluation of HAM

It has been confirmed by neurologists that AD often causes at least moderate cortical atrophy, enlargement of the lateral ventricles, and enlarged temporal horns, which are most macroscopic in MRI volumes [25]. For a fair comparison against recent state-of-the-art methods, we implemented these approaches based on the provided source code on the same network with an AUC of 0.992 [4]; the results can be found in Fig. 6. Figure 6(a) shows the ground-truth segmented areas of the cerebral cortex, lateral ventricle and hippocampus for an MRI volume with AD. The subjective comparisons shown in Fig. 6(b) and (c) indicate that one can hardly identify distinct areas from the heatmaps generated by Grad-CAM [9] and Grad-CAM++ [10]. These heatmaps mostly identify white matter, whose lesions are not macroscopic when evaluating the brain disease and thus do not constitute trustworthy visualization evidence. In addition, CAMERAS [8], presented in Fig. 6(d), shows heatmaps with larger brain areas but fails to identify pathological abnormalities and discriminative disease areas. Subsequently, RISE [36] and Score-CAM [6] obtain relatively unambiguous heatmaps, but still fail to highlight known patterns of atrophy such as the lateral ventricles and hippocampus. Finally, our proposed method HAM generates high-quality heatmaps (Fig. 6(g)) that show discriminative localizations of brain abnormalities instead of blurry responses, making it outperform the existing works remarkably. This can be attributed to the effective latent features aggregated from the intermediate layers as well as the last convolutional layer.

Meanwhile, we also apply the above visualization methods to our proposed MAXNet. Fig. 7(a) shows an example of a normal MRI scan, while Fig. 7(b)-(f) present the heatmaps produced by the state-of-the-art approaches. As can be observed, these methods do not perform well, and some of them even render strange visual outcomes. More importantly, some areas such as the lateral ventricles and hippocampus are still emphasized even though this is a normal case without AD (Fig. 7(d)). Therefore, we find that existing visualization methods do not work well with our proposed explainable model MAXNet. By contrast, the proposed HAM shows significantly better visualization results that highlight small regions and ignore most non-neuropathy areas, as shown in Fig. 7(g).

We further evaluate HAM on three metrics: runtime, deletion and insertion, which have been adopted in several recent works [34]. An example of the deletion and insertion curves for a test volume is illustrated in Fig. 8, and the average performance over 2500 perturbed volumes is given in Table 3. As observed, compared to both Grad-CAM [9] and Grad-CAM++ [10], CAMERAS [8] provides better performance; the former two are potentially hampered by noisy heatmaps caused by the up-sampling operation. Score-CAM [6] obtains state-of-the-art performance on both the deletion and insertion metrics, but its deletion curve is steeply convex, which means its feature selection is not stable. Instead, our proposed HAM outperforms the other approaches on both the deletion and insertion metrics. This also demonstrates that the heatmaps generated via HAM capture most of the salient features of brain disease areas and contain less noise. For the runtime metric, it can be observed that RISE [36], Score-CAM [6] and CAMERAS [8] all make considerable demands on time, because augmented feature maps or inputs are fed into the models multiple times.


Fig. 6. Visual results of visualization methods. Note that (a)-(f) were performed over a 3D CNN with an AUC of 0.992 [4]. (a) input with “AD” label. Ground truth of cerebral
cortex, lateral ventricle and hippocampus via FreeSurfer are highlighted, (b) Grad-CAM [9], (c) Grad-CAM++ [10], (d) CAMERAS [8], (e) RISE [36], (f) Score-CAM [6], (g)
proposed HAM-generated heatmaps which highlight enlarged sulcal spaces caused by atrophy and pathological abnormalities of cerebral cortex and hippocampus.

Fig. 7. (a) Input with “Normal” label, (b) Grad-CAM [9], (c) Grad-CAM++ [10], (d) CAMERAS [8], (e) RISE [36], (f) Score-CAM [6], (g) proposed HAM.

Fig. 8. (a) The deletion curve for Grad-CAM [9], Grad-CAM++ [10], CAMERAS [8], RISE [36], Score-CAM [6], and HAM. The x-axis represents the percentage of removed voxels,
while the y-axis is the corresponding predicted score. Specifically, a steeper slope indicates a better explanation. (b) The insertion curve for Grad-CAM [9], Grad-CAM++ [10],
CAMERAS [8], RISE [36], Score-CAM [6], and HAM. The x-axis shows the percentage of added voxels, and the y-axis is the corresponding predicted score. Specifically, a
fast-rising slope implies a better explanation.

As expected, Grad-CAM [9] and Grad-CAM++ [10] are the fastest methods, because in theory they do not need to make multiple model predictions with perturbed inputs. Although Grad-CAM [9] achieves a faster runtime than our HAM, it obtains a substantially lower insertion score (0.492 vs. 0.801) and a higher deletion score (0.822 vs. 0.263). These promising results suggest that our proposed approach HAM is able to identify the salient features that are responsible for the model decisions.

Table 4
Comparative evaluation of PCR.

            Swap Deletion Confidence      Swap Insertion Confidence
            W/O PCR      PCR              W/O PCR      PCR
ρ_{m,r1}    0.275        0.784            0.297        0.771
ρ_{m,r2}    0.182        0.803            0.254        0.755
ρ_{m,r3}    0.345        0.818            0.278        0.764
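To ground the PCR comparison in Table 4, below is a hedged sketch of how the swap confidence scores of Definitions 2 and 3 could be computed, assuming `score_fn` returns the model's score vector for a volume and the binary masks mark the voxel positions attributed to p_k and p_c.

```python
import torch

def swap_deletion_confidence(score_fn, m, mask_k, mask_c):
    """Definition 2: mask_k / mask_c zero out the voxels attributed to
    p_k / p_c (deletion); the score changes are compared cosine-style."""
    d_k = score_fn(m) - score_fn(m * mask_k)
    d_c = score_fn(m) - score_fn(m * mask_c)
    return (d_k * d_c).sum() / (d_k.norm() * d_c.norm() + 1e-8)

def swap_insertion_confidence(score_fn, m, keep_k, keep_c):
    """Definition 3: keep_k / keep_c retain only the attributed voxels
    (insertion), and the resulting scores are compared the same way."""
    s_k = score_fn(m * keep_k)
    s_c = score_fn(m * keep_c)
    return (s_k * s_c).sum() / (s_k.norm() * s_c.norm() + 1e-8)
```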

6.4.2. Qualitative analysis of PCR

Table 4 compares the average values of the proposed metrics, Swap Deletion Confidence and Swap Insertion Confidence, generated with and without PCR (i.e., retrieving samples randomly from MRI scans with the same label). Compared to the baseline, which provides relevant samples stochastically, the PCR module is able to find MRI cases with similar pathological abnormalities. As a result, it obtains higher values of the Swap Deletion Confidence and Swap Insertion Confidence. Fig. 9 displays an example of a test MRI volume with the label "AD" and three reference samples. It is noted that the heatmaps of the reference samples are quite similar to the input's heatmaps. PCR highlights the atrophy of the cerebral cortex, the pathological abnormalities of the hippocampus, and the enlargement of the lateral ventricle, which have salient features related to the model decision. This confirms PCR's ability to retrieve samples with similar disease areas.


Fig. 9. Given the input m_k^T, the explainable tool provides HAM-generated heatmaps and three reference samples R_c, c ∈ [1, 2, 3], whose latent features are the most similar to m_k^T's.

Table 5
Contributions of individual modules in the proposed MAXNet on subset 1. A check mark denotes that the module is retained; values indicating mask collapse are blank.

L_cls + L_ctr   DAM   MFM   ACC           Insertion   Deletion
✓               -     -     0.725±0.003   0.445       0.398
-               ✓     -     0.752±0.075   0.546       0.446
-               -     ✓     0.763±0.008   0.327       0.338
✓               ✓     -     0.899±0.140   0.762       0.636
✓               -     ✓     0.856±0.021   0.526       0.553
-               ✓     ✓     0.916±0.012   0.678       0.703
✓               ✓     ✓     0.953±0.002   0.801       0.263

6.5. Ablation study

We evaluate the contribution of each constituent module in our proposed MAXNet and the HAM method on the subset of ADNI data, i.e., L_cls + L_ctr, DAM and MFM, by removing at least one module at a time. The detailed results are shown in Table 5, from which several observations can be made. (1) If we eliminate the loss L_cls + L_ctr, the model adapts to make predictions or produce heatmaps primarily based on the last convolutional features, which fails to achieve high classification accuracy, as MRI volumes are difficult to classify using coarse feature maps. (2) If we abandon the DAM, each sample can only be learned from single-resolution activation features stemming from the high-scale level; consequently, there is a drastic drop in both the average accuracy and the insertion score. (3) Finally, if we discard the MFM and simply concatenate all the multi-resolution feature maps, the model is trapped in local optima and cannot be further improved.

6.6. Discussions

There are mainly two factors which may limit the performance of our work. First, since ground-truth annotations of pathological abnormalities, such as bounding boxes of brain regions, are not publicly available, our model learns the class-discriminative latent features in a weakly supervised manner (i.e., under the supervision of image-level class labels), which leads to inaccuracy in the identification of different types of pathologies. Second, the proposed model does not incorporate any medical domain knowledge as input. As a result, our model may not learn features that exactly match the prior knowledge of relevant professionals such as doctors and clinicians.

7. Conclusion

In this paper, we integrated several novel modules to constitute a novel explainable framework by employing 3D deep learning techniques. A novel explainable network dubbed MAXNet was proposed to classify AD, in which we introduced the DAM and MFM blocks to aggregate multi-resolution feature activation maps into the latent space; this was not only representative for high-resolution explanations, but also crucial for model predictions.


model to learn interpretable features w.r.t. target labels in the la- [13] X. Zhu, H. Suk, D. Shen, Multi-modality canonical feature selection for
tent space. Additionally, we provided an explainable tool which is Alzheimer’s disease diagnosis, in: P. Golland, N. Hata, C. Barillot, J. Hornegger,
R.D. Howe (Eds.), Medical Image Computing and Computer-Assisted Interven-
comprised of HAM to generate voxel-wise information, and the tion (MICCAI), vol. 8674, 2014, pp. 162–169. Boston, MA, USA
PCR module to provide similar samples as the prediction basis. [14] P. Cao, X. Shan, D. Zhao, M. Huang, O. Zaiane, Sparse shared structure
By comparing the proposed model to other state-of-the-art meth- based multi-task learning for MRI based cognitive performance prediction of
Alzheimer’s disease, Pattern Recognit. 72 (2017) 219–235.
ods through extensive experiments, we validated the effectiveness [15] E. Gerardin, G. Chételat, M. Chupin, R. Cuingnet, B. Desgranges, H.-S. Kim,
of our model with good diagnostic accuracy. Moreover, the model M. Niethammer, B. Dubois, S. Lehéricy, L. Garnero, et al., Multidimensional
was capable of providing insightful explanations about its deci- classification of hippocampal shape features discriminates Alzheimer’s disease
and mild cognitive impairment from normal aging, Neuroimage 47 (4) (2009)
sions. Both factors are conducive to applying deep learning models
1476–1486.
to clinical applications. [16] L. Sørensen, C. Igel, A. Pai, I. Balas, C. Anker, M. Lillholm, M. Nielsen,
Despite the encouraging performance gained by our work, it suffers from two limitations. Firstly, the proposed model does not incorporate medical domain knowledge; it could potentially be further improved if prior knowledge from medical professionals were integrated. Secondly, the latent features are learnt in a weakly supervised manner owing to the lack of publicly available annotations of pathological abnormalities, which may cause the model to neglect possible pathological locations in the brain. Therefore, we plan to develop approaches to learn and integrate expert knowledge in future work.
Declaration of Competing Interest

The authors declare that they have no known competing financial interests or personal relationships that could have appeared to influence the work reported in this paper.
Acknowledgment

This research was supported in part by the Australian Government through the Australian Research Council's Discovery Projects funding scheme (project DP220101634).