SPIEDigitalLibrary.org/conference-proceedings-of-spie
Fei Liu, Huabin Wang, Yonglin Chen, Yu Quan, Liang Tao, "Convolutional
neural network based on feature enhancement and attention mechanism for
Alzheimer's disease prediction using MRI images," Proc. SPIE 12083,
Thirteenth International Conference on Graphics and Image Processing
(ICGIP 2021), 120830X (16 February 2022); doi: 10.1117/12.2623580
ABSTRACT
Magnetic Resonance Imaging (MRI) is the mainstream modality for predicting Alzheimer's disease, but traditional machine learning methods based on MRI predict Alzheimer's disease with low accuracy. Although a Convolutional Neural Network (CNN) can extract image features automatically, convolution operations focus only on local regions and lose global connections. The attention mechanism can attend to local and global information at the same time, and improves model performance by strengthening key information and suppressing invalid information. Therefore, this paper constructs a deep CNN based on multiple attention mechanisms for Alzheimer's disease prediction. Firstly, the MRI image is enhanced by cyclic convolution, which strengthens the feature information of the original image and thereby improves prediction accuracy and stability. Secondly, multiple attention mechanisms are introduced to re-calibrate features and adaptively learn feature weights, identifying brain regions that are particularly relevant for disease diagnosis. Finally, an improved VGG model is proposed as the backbone network: max pooling is replaced by average pooling to retain more image information, and network efficiency is improved by reducing the number of neurons in the fully connected layers, which also suppresses over-fitting. The experimental results show that the prediction accuracy, sensitivity and specificity of the proposed method based on multiple attention mechanisms are 99.8%, 99.9% and 99.8%, respectively, which is superior to existing mainstream methods.
Keywords: MRI, Alzheimer's disease, Convolutional Neural Network, Attention mechanism, Cyclic convolution
1. INTRODUCTION
Alzheimer's disease (AD) is a neurodegenerative disease. In the brains of Alzheimer's patients, the neurons supporting memory, thinking and learning are destroyed, leading to dementia syndrome. Alzheimer's disease brings considerable harm to patients: it not only causes a serious decline in quality of life and induces depression and anxiety, but also imposes a huge economic burden on patients and their families. It has become the fourth largest killer of human health after cardiovascular disease, cancer and stroke. According to the World Alzheimer Report, there are more than 50 million Alzheimer's patients worldwide, the number will reach 150 million by 2050, and the cost of caring for these patients exceeds that of the first three diseases combined.
2. RELATED WORK
With the development of deep learning technology, more and more researchers use deep convolutional models to predict AD. These models are more effective than previous machine learning models because convolutional networks can capture more subtle features. Abrol et al. [12] used improved forms of deep residual neural networks to predict progression from MCI to AD. The model was first pre-trained on AD and CN individuals, and transfer learning techniques were then used to predict MCI. Their research showed that three residual blocks work best, and that increasing the number of residual blocks does not improve prediction. Venugopalan et al. [7] analyzed MRI,
3. METHODS
A CNN is used to extract disease-related feature information from MRI images to predict Alzheimer's disease. To capture more detailed information from the MRI images, the max-pooling layers of VGG are replaced by average-pooling layers. At the same time, the number of neurons in the fully connected layers is reduced, which not only reduces the parameter count of the fully connected layers but also inhibits over-fitting. In the cyclic convolution layer, the results of the convolution operation are added to the original image, thereby enhancing the high-order features of the original image.
This paper selects MRI data samples from the Alzheimer's Disease Neuroimaging Initiative (ADNI) database. The specific selection criteria are: (1) MRI data from the ADNI-1 period; (2) data grouped into three classes: CN, MCI and AD; (3) T1-weighted images acquired with the MP-RAGE sequence; (4) an image size of 256×166×256 voxels. The experimental dataset consists of MRI images from each subject acquired in different years, so the amount of data used is much larger than the number of subjects. Table 1 shows the sample statistics for this experiment.
The collected MRI images are preprocessed as follows: (1) the data are categorized and sent to the SPM software to register the scans of all subjects into a standard spatial coordinate system; the specific processing order is head-motion correction, registration, segmentation, normalization and smoothing; (2) after SPM processing, the gray matter, white matter and cerebrospinal fluid images are sliced along the axial direction; starting from the slice at index 120, a total of 16 slices are taken from each MRI image; (3) the gray matter, white matter and cerebrospinal fluid slices are fused into a 3-channel image. Eventually, a total of 16,128 fused images were generated across the three categories, each with dimensions of 3×166×256 pixels. The slices selected in this paper cover the lateral ventricle, inferior temporal lobe and middle temporal cortex, which are associated with AD and MCI. As shown in figure 1, the white matter, gray matter and cerebrospinal fluid slices are fused as input to the model.
Figure 1. AD images of fused cerebrospinal fluid, white matter and gray matter
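The slicing and fusion step can be sketched as follows. This is a minimal NumPy illustration, not the paper's code: the random volumes and the function name are our assumptions; only the slice start index (120), slice count (16) and fused shape (3×166×256) come from the text above.

```python
import numpy as np

# Hypothetical volumes standing in for the SPM-segmented gray matter,
# white matter and CSF maps; the paper's volumes are 256x166x256 voxels.
gm = np.random.rand(256, 166, 256)
wm = np.random.rand(256, 166, 256)
csf = np.random.rand(256, 166, 256)

def fuse_slices(gm, wm, csf, start=120, count=16):
    """Stack matching axial slices into 3-channel (C, H, W) images."""
    images = []
    for i in range(start, start + count):
        # one slice from each tissue map becomes one channel
        img = np.stack([gm[i], wm[i], csf[i]], axis=0)  # (3, 166, 256)
        images.append(img)
    return images

images = fuse_slices(gm, wm, csf)
print(len(images), images[0].shape)  # 16 images of shape (3, 166, 256)
```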
As shown in figure 2, the CNN model designed in this paper consists of three parts: (1) a cyclic convolution layer that adds the acquired feature maps to the original images; (2) an improved VGG used as the backbone network to extract feature information related to Alzheimer's disease; (3) attention modules that automatically select the features most effective for predicting the disease by learning weights.
Backbone network: In this paper, an improved 11-layer VGG network is used as the backbone network to extract features. To preserve as much of the local correlation information of the image as possible, the max-pooling layers are replaced by average-pooling layers. To make the model more stable during training, a batch normalization layer is added after each convolution layer. To reduce the number of model parameters and accelerate training, the number of output features of the fully connected layers is reduced. The backbone network consists of eight convolution blocks (size=3×3, padding=1); average pooling (size=2×2, stride=2) is added after blocks 1, 2, 4 and 6, while the max pooling of the last block is retained. Table 2 lists the details of the network.
Table 2. Network configuration of the Alzheimer's disease prediction model based on multiple attention mechanisms

Layer                      Kernel                        Output
Cyclic convolution layer   3×3, stride=1, padding=1      3×166×256
Block 1                    3×3, 64-d (×2)                64-d
FC 1                       512 neurons                   512×1
FC 2                       96 neurons                    96×1
FC 3                       3 neurons                     3×1
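The pooling swap in the backbone can be illustrated with a minimal NumPy sketch (the function and the tiny 4×4 example are ours, not from the paper): average pooling keeps a contribution from every pixel in the 2×2 window, whereas max pooling keeps only the strongest response.

```python
import numpy as np

def avg_pool2x2(x):
    """2x2 average pooling, stride 2, on a (C, H, W) feature map:
    every value in the window contributes, so more local image
    information is retained than with max pooling."""
    c, h, w = x.shape
    x = x[:, :h - h % 2, :w - w % 2]          # crop odd edges
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

x = np.arange(16, dtype=float).reshape(1, 4, 4)
print(avg_pool2x2(x))  # [[[2.5, 4.5], [10.5, 12.5]]]
```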
Attention module: The CBAM module is added to the first convolution block, with channel attention (CA) and spatial attention (SA) connected in series. The CBAM module is shown in figure 3, where W, H and C denote the width, height and number of channels of the feature map. As the input feature map passes through the channel attention module, global pooling compresses each W×H feature map along the spatial dimensions to 1×1 while keeping the number of feature maps unchanged, and the sigmoid function is used to compute the channel weights.
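A minimal NumPy sketch of this channel-attention step (the weight shapes, reduction ratio of 2 and random inputs are illustrative assumptions; the real module operates on the backbone's feature maps):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def channel_attention(x, w1, w2):
    """CBAM-style channel attention on a (C, H, W) feature map.

    Global average and max pooling compress each HxW map to a scalar,
    a shared two-layer MLP (w1, w2) scores the channels, and a sigmoid
    turns the summed scores into per-channel weights."""
    avg = x.mean(axis=(1, 2))          # (C,) average-pooled descriptor
    mx = x.max(axis=(1, 2))            # (C,) max-pooled descriptor
    score = w2 @ np.maximum(w1 @ avg, 0) + w2 @ np.maximum(w1 @ mx, 0)
    weights = sigmoid(score)           # (C,), one weight per channel
    return x * weights[:, None, None]  # re-calibrated feature map

rng = np.random.default_rng(0)
c = 8
x = rng.standard_normal((c, 6, 6))
w1 = rng.standard_normal((c // 2, c))  # reduction ratio 2 (assumption)
w2 = rng.standard_normal((c, c // 2))
out = channel_attention(x, w1, w2)
print(out.shape)  # (8, 6, 6)
```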
As shown in figure 2, the Non-Local attention module (NL) is added after the fourth convolution block of the backbone network; NL realizes global connectivity directly by strengthening long-distance dependencies. Convolution operations are local: their receptive fields are typically 3×3 or 5×5, so they consider only local regions. The NL module, by contrast, has a very large receptive field and can therefore see the overall picture. The NL operation takes the following form:
y_i = (1/∁(x)) Σ_{∀j} f(x_i, x_j) g(x_j)
Here i and j index spatial positions of the input, the function f computes the similarity between any two positions, and the function g computes the feature representation at position j. The output y is obtained after normalizing by the response factor ∁(x) (a softmax). The Embedded Gaussian is used as the f function, with 1×1 convolutions as θ and ϕ:
f(x_i, x_j) = e^{θ(x_i)^T ϕ(x_j)}

∁(x) = Σ_{∀j} f(x_i, x_j)
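Once the H×W grid is flattened to N positions, the two formulas above reduce to softmax attention over pairwise similarities. A NumPy sketch (the feature sizes and random projection matrices are illustrative):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def non_local(x, w_theta, w_phi, w_g):
    """Embedded-Gaussian non-local operation on (N, C) features, where
    N = H*W flattened spatial positions. The 1x1 convolutions theta,
    phi and g become plain matrix products after flattening."""
    theta = x @ w_theta          # (N, d) embeddings theta(x_i)
    phi = x @ w_phi              # (N, d) embeddings phi(x_j)
    g = x @ w_g                  # (N, d) representations g(x_j)
    sim = theta @ phi.T          # log of f(x_i, x_j) for all pairs
    attn = softmax(sim, axis=1)  # exp(.) normalized by C(x) = sum_j f
    return attn @ g              # y_i = (1/C(x)) sum_j f(x_i,x_j) g(x_j)

rng = np.random.default_rng(1)
x = rng.standard_normal((16, 8))   # 16 positions, 8 channels
w = [rng.standard_normal((8, 4)) for _ in range(3)]
y = non_local(x, *w)
print(y.shape)  # (16, 4)
```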
After combining the channel attention, spatial attention and global attention modules, the model can attend to the local and global features of the input image at the same time, thereby enhancing the representativeness of the features.
Cyclic convolution layer: By stacking the results of the two-layer convolution onto the input data, the salient features of the region are strengthened, allowing the prediction network to capture these features faster. The cyclic convolution takes the following form, where x_t denotes the input at the current step, h_{t−1} the output of the previous step, and f a linear activation function:

h_t = f(x_t + W·h_{t−1} + x_{t−1})
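Read literally, the recurrence above can be traced step by step; a small sketch with scalar inputs, an identity activation and sample values of our own choosing (the actual layer operates on feature maps with a convolutional W):

```python
def cyclic_step(xs, w, f=lambda z: z):
    """Unroll h_t = f(x_t + W*h_{t-1} + x_{t-1}) over a sequence xs,
    here with scalar inputs and an identity activation f."""
    h = 0.0       # h_0: no previous output
    x_prev = 0.0  # x_0: no previous input
    for x in xs:
        h = f(x + w * h + x_prev)
        x_prev = x
    return h

print(cyclic_step([1.0, 2.0], w=0.5))  # 2 + 0.5*1 + 1 = 3.5
```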
The experimental dataset is divided into a train set and a test set. The train set contains 10,000 images and the test set contains 6,128 images. During training, 500 images are randomly selected from the train set each time as the validation set. After each epoch, the accuracy of the model on the validation set is calculated, and the parameters with the highest validation accuracy are retained. The input images have a shape of 3×166×256 pixels and the batch size is 50. Every three epochs, the performance of the model is evaluated on the test set, and training stops when the model falls into over-fitting. The stochastic gradient descent optimization algorithm is used, with an initial learning rate of 0.01 and a momentum of 0.9. Some data may be particularly sensitive to label changes, causing an abnormal increase in the loss function and making the model unstable. To increase the stability of the model, cross-entropy loss with label smoothing regularization is used; this preserves the generalization ability of the model and prevents over-sensitivity to labels.
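The smoothed loss can be sketched as follows (eps=0.1 and the example logits are illustrative choices of ours; the paper does not state its smoothing factor):

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy with label smoothing: the one-hot target is mixed
    with a uniform distribution over the K classes, so no sample can
    push the model toward assigning probability 1 to a single label."""
    k = logits.shape[0]
    q = np.full(k, eps / k)   # uniform mass eps/K per class
    q[target] += 1.0 - eps    # remaining mass on the true label
    return -(q * log_softmax(logits)).sum()

logits = np.array([2.0, 0.5, -1.0])
plain = smoothed_cross_entropy(logits, target=0, eps=0.0)
smooth = smoothed_cross_entropy(logits, target=0, eps=0.1)
print(plain < smooth)  # True: smoothing penalizes confident logits
```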
To evaluate the classification performance of the model, the classification accuracy (ACC), sensitivity (SEN), specificity (SPE) and F1-score are calculated in the experiments. ACC is the proportion of correctly classified samples in the total number of samples. SEN is the proportion of correctly classified positive samples among all positive samples. SPE is the proportion of correctly classified negative samples among all negative samples.
The dataset is randomly divided into independent train and test sets before the experiment begins. The network is built with the PyTorch deep learning framework.
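The evaluation metrics defined above can be computed from a confusion count; a self-contained sketch with toy labels of our own (1 = positive/diseased, 0 = negative/control):

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, sensitivity and specificity from binary labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    acc = (tp + tn) / len(y_true)  # correct / all samples
    sen = tp / (tp + fn)           # correct positives / all positives
    spe = tn / (tn + fp)           # correct negatives / all negatives
    return acc, sen, spe

print(binary_metrics([1, 1, 0, 0], [1, 0, 0, 0]))  # (0.75, 0.5, 1.0)
```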
This paper uses an 11-layer VGG network to extract MRI features. During feature extraction, attention modules are added to strengthen the expressive ability of the network. To verify the predictive performance of the network, we compared the results of the original VGG network with those of the VGG-Advanced network and the VGG-Attention network. As shown in table 3, the average accuracy of the VGG-Advanced network is 99.6%, up about 0.4% compared with the original VGG network. After adding the attention mechanism, accuracy rises by about a further 0.2%. Finally, the average prediction accuracy of our model is 99.8%, and the other performance indicators of the model are also superior to those of the original VGG network.
Table 3. Experimental analysis of ablation
We chose classic CNN models for the model comparison experiment. VGG was developed by Karen Simonyan et al. and took first place in the localization task and second place in the classification task at the 2014 ImageNet Challenge. VGG uses small convolution kernels (3×3), whereas previous networks used larger convolution kernels (11×11, 5×5, etc.). Using small convolutions has two main benefits: the same receptive field is obtained with much less computation, and two layers of 3×3 convolution introduce more non-linearity than one layer of 5×5 convolution, making the fitting ability of the model stronger. GoogLeNet was proposed by Christian Szegedy et al.; in addition to continuing to increase network depth, it superimposes multiple convolutions of different sizes within the same layer. ResNet, proposed by Kaiming He et al., uses shortcut connections to make the network very deep, solving the problem that increasing network depth causes gradients to vanish. ResNet comes in network structures of various depths, such as ResNet-18, ResNet-50 and ResNet-101.
Table 4. Comparative experimental analysis of models
As shown in table 4, our proposed method achieves the best results on all performance indicators: accuracy, sensitivity, specificity and F1-score. We also observed very low sensitivity when predicting AD with the AlexNet network. GoogLeNet uses inception modules that stack multiple convolution kernels of different sizes to extract features, and it is more complex than AlexNet and VGG; however, the experimental results show that its prediction performance is limited. ResNet has high specificity for AD prediction, and increasing the depth of the network significantly improves performance (the ResNet-50 indicators are better than those of ResNet-18). Meanwhile, although VGG is neither as deep nor as wide as GoogLeNet and ResNet, its performance is superior to GoogLeNet and ResNet-18 and close to ResNet-50. As shown in figure 6, after the ninth iteration the accuracy and sensitivity of ResNet-50 decline significantly, while VGG shows only a small decline; ResNet-18 fluctuates greatly during training. Our proposed model improved steadily during training, with indicators consistently higher than those of the other networks, indicating that it identifies key features that effectively and accurately predict AD. Moreover, compared with the ResNet networks, the proposed network greatly improves computational efficiency with a more concise structure and ultimately achieves the best prediction performance.
Table 5 shows the experimental results of predicting AD with different methods. They all use MRI scan data from the ADNI database, but the subjects selected in each experiment differ. Comparing the results of the different methods, we find that the proposed method is superior to the existing mainstream methods in accuracy, sensitivity and specificity.
Asl et al. [22] used a 3D-CNN and transfer learning techniques to predict AD, finally obtaining 97.6% accuracy in CN/AD classification. Bumshik et al. [19] removed the noise pixels from the sMRI image intensity
As shown in figure 7, as the number of training iterations increases, the accuracy curve tends to 100% and the loss to 0. About 90% of the errors are due to the model's failure to correctly distinguish between MCI and AD, because of the unclear definition of MCI and the discrepancy between MRI scan results and the actual diagnosis of the patient's condition.
This paper presents a prediction method for Alzheimer's disease based on multiple attention mechanisms. An AD prediction model is established using a deep convolutional neural network, and the model consists of three parts. The first part uses cyclic convolution to capture feature information. In the second part, an improved VGG network is used to extract disease features and output feature vectors. The third part uses the extracted feature vectors to predict AD. We add channel, spatial and global attention mechanisms to the first and fourth layers of the backbone network to capture both the global and local pathological features of the image in favor of AD prediction. The results of the ablation experiments show that the attention mechanisms improve the prediction performance of the network. To evaluate the performance of the proposed model objectively, we compare its results with methods from the literature using the same ADNI database. The results show that the performance of our proposed model is superior to the existing mainstream methods: accuracy, sensitivity, specificity and F1-score on the test set are 99.8%, 99.9%, 99.8% and 0.996, respectively. In future research, we will combine multi-modal data (sMRI, PET, fMRI, neuropsychological cognitive evaluation scores, etc.) for AD prediction and further explore the interpretability of the AD prediction model.
ACKNOWLEDGMENTS
This work was supported in part by the National Natural Science Foundation of China under Grant 61372137,
in part by the Natural Science Foundation of Anhui Province under Grant 1908085MF209 and in part by the
Natural Science Foundation for the Higher Education Institutions of Anhui Province under Grant KJ2019A0036.
Huabin Wang is the corresponding author. E-mail address: wanghuabin@ahu.edu.cn.