Abstract
Deep learning approaches for clinical predictions based on magnetic resonance imaging data have shown great promise as a translational technology for diagnosis and prognosis in neurological disorders, but their clinical impact has been limited. This is partially attributed to the opaqueness of deep learning models, causing insufficient understanding of what underlies their decisions. To overcome this, we trained convolutional neural networks on structural brain scans to differentiate dementia patients from healthy controls, and applied layerwise relevance propagation to procure individual-level explanations of the model predictions. Through extensive validations we demonstrate that deviations recognized by the model corroborate existing knowledge of structural brain aberrations in dementia. By employing the explainable dementia classifier in a longitudinal dataset of patients with mild cognitive impairment, we show that the spatially rich explanations complement the model prediction when forecasting transition to dementia and help characterize the biological manifestation of disease in the individual brain. Overall, our work exemplifies the clinical potential of explainable artificial intelligence in precision medicine.
Introduction
Since its invention in the 1970s, magnetic resonance imaging (MRI) has provided an opportunity to non-invasively examine the inside of the body. In neuroscience, images acquired with MRI scanners have been used to identify how the brains of patients with various neurological disorders differ from their healthy counterparts. Stereotypically, this has been done by collecting data from a group of patients with a given disorder and a comparable group of healthy controls, on which traditional statistical inference is applied to identify spatial locations of the brain where the groups differ1. Typically, these locations are not atomic locations identified by spatial coordinates, but rather morphological regions defined by an atlas, derived from empirical or theoretical insights of how the brain is structured. Differences between groups are described using morphometric properties like thickness or volume of these prespecified regions. A major benefit of this approach is the innate interpretability of the results: on average, patients with a given disorder deviate in a specific region of the brain in a comprehensible manner. Furthermore, the high degree of localization offered by modern brain scans allows for accurate characterization of where and how the brain of an individual deviates from an expected, typically healthy, norm2. However, the effects which are found are typically small3 with limited predictive power at the individual level4,5, which in turn has raised questions about whether these analytical methods are expressive enough to model complex mental or clinical phenomena6. As an alternative, new conceptual approaches are proposed, advocating modeling frameworks with increased expressive power that allow for group differences through complex, non-linear interactions between multiple, potentially distant, parts of the brain7, with a focus on prediction8. 
Such modeling flexibility is naturally achieved with artificial neural networks (ANNs), a class of statistical learning methods that combines aspects of data at multiple levels of abstraction, to accurately solve a predictive task9. However, while this often yields high predictive performance, e.g., by demonstrating clinically sufficient case-control classification accuracy for certain conditions, it comes at the cost of interpretation, as the models employ decision rules not trivially understandable by humans10. When the goal of the analysis is clinical, supporting the diagnosis and treatment of someone affected by a potential disorder, this opaqueness presents a substantial limitation. Thus, development and empirical validation of new methods within clinical neuroimaging that combine predictive efficacy with individual-level interpretability is imperative, to facilitate trust in how the system is working, and to accurately describe inter-individual heterogeneity.
With more than 55 million individuals afflicted worldwide11, over 25 million disability-adjusted life years lost12,13 and a cost exceeding one trillion USD yearly14, dementia is a prime example of a neurological disorder that incurs a monumental global burden. Due to the aging global population, the prevalence is expected to nearly triple by 205015, inciting a demand for technological solutions to facilitate handling the upcoming surge of patients. Dementia is a complex and progressive clinical condition16 with multiple causal determinants and moderators. Alzheimer’s disease (AD) is the most common form and accounts for 60–80% of all cases11. However, the brain pathologies underlying different subtypes of dementia are not disjoint, but often co-occur17,18,19, and have neuropathological commonalities20. The most prominent is neurodegeneration, occurring both in specific regions like the hippocampus and globally across the brain21, and inter-individual variations in the localization of atrophy have been associated with impairments in specific cognitive domains22,23. Thus, the biological manifestation of dementia in the brain is heterogeneous24, resulting in distinctive cognitive and functional deficits20, highlighting the need for precise and personalized approaches to diagnosis. For patients with mild cognitive impairment (MCI), a potential clinical precursor to dementia, providing individualized characterizations of the underlying etiological disease at an early stage could widen the window for early interventions25, alleviate uncertainty about the condition, and help with planning for the future26.
In dementia, ANNs, and particularly convolutional neural networks (CNNs), have been applied to brain MRIs to differentiate patients from controls27,28, prognosticate outcomes29, and differentially diagnose subtypes30. However, while research utilizing this technology has been influential, clinical translations are scarce31. Where techniques for segmenting brain tumors or detecting lesions typically produce segmentation masks that are innately interpretable, predicting a complex diagnosis would entail compressing all information contained in a high-dimensional brain scan into a single number. Using deep learning, the decisions underlying this immense reduction are obfuscated from the developer of the system, the clinical personnel using it, and the patient ultimately impacted by the decision. This black box nature is broadly credited for the low levels of adoption in safety-critical domains like medicine32. Responding to this limitation, explainable artificial intelligence (XAI) provides methodology to explain the behavior of ANNs33. The nature of these explanations varies, e.g., by what type of model is to be explained, what conceptual level the explanation is at, and who it is tailored for34,35. In computer vision, XAI typically aims for post-hoc explanations of individual decisions, explaining why a model arrived at a given prediction for a given image. Explanations are often provided in a visual format, as a heatmap indicating how different regions of the image contribute to the prediction.

Figure caption: a The top half shows the prediction from the dementia model at each visit, while the bottom part displays the relevance map underlying the prediction. The opaque sections (including c, d, and e) contain information accessible at the imagined current timepoint (22.02.07) to support a clinician in a diagnostic procedure. The angle (\(\angle\)) represents the change in dementia prediction per year based on the first two visits. b Translucent regions reveal the morphological record for the remaining follow-ups in the dataset, thus depicting the future. The ground truth diagnostic trajectory is encoded by the color of the markers. c Predicted probabilities of progression at future follow-ups based on the prediction and relevance map at the current timepoint. d Survival curve of the patient compared to the average MCI patient, calculated from the prediction and relevance map. The marker indicates the location of the patient at the current timepoint. e A list of cognitive domains where the patient is predicted to significantly differ from the average, based on the prediction and relevance map.
The regions with the highest density of relevance in our maps were the nucleus accumbens, amygdala and the parahippocampal gyrus, all of which are strongly affected in dementia52,53,54. While the latter two corroborate the established involvement of the medial temporal lobe55, it is surprising that the hippocampus does not appear in our analyses, as it frequently has in similar studies38,41,42. While this could be caused by the actual localization of pathology56, we consider it more likely to be related to the internal machinery of the model. Specifically, the CNN relies on spatial context to identify brain regions before assessing their integrity, utilizing filters that span areas of the image larger than those containing the region itself. In the backwards pass, LRP uses these filters, and thus the localization of relevance is not necessarily voxel-precise. Furthermore, we believe the model can broadly be seen as an atrophy detector, which necessarily entails looking for gaps surrounding regions instead of directly at the regions themselves. Therefore, while the relevance maps provide important information, they depend on contextual information and thus rely on interpretation from clinicians to maximize their utility in clinical practice.
We focused our analyses mainly on the relevance maps, but the results with the largest, immediate potential for clinical utility were the predictions from the dementia classifier. Other studies have shown the efficacy of machine learning models in differentiating dementia patients and healthy controls28, but it is intriguing that we see a large discrepancy in the predictions of the progressive and non-progressive MCI patients many years before the dementia diagnosis is given. This corroborates findings from theory-driven studies57 and a recent deep learning study27, implying detectable structural brain changes many years before the clinical diagnosis. This gives hope for advanced technology to contribute to early detection and diagnosis through MRI-based risk scores, in our case supported by a visual explanation. If curative treatments prove efficacious and become accessible, early identification of eligible patients could be imperative58. Furthermore, timely access to interventions has shown efficacy in slowing the progression of cognitive decline59, in addition to improving the quality of life for those afflicted and their caregivers26,60. Widely accessible technology that allows for early detection with high precision could play a key role in the collective response to the impending surge of patients and provide an early window of opportunity for more effective treatments.
While our results show a great potential for explainable AI, and particularly LRP, as a translational technology to detect and characterize dementia, there are limitations to our study. First, there are technical caveats to be aware of. Most importantly, there is an absolute dependence between the predictions of our model and the relevance maps. In our case, when we qualitatively assessed the relevance maps of the false negatives, they were indistinguishable from the true negatives. This emphasizes the fact that when the model is wrong, this is not evident from the explanations. Next, while the maps contain information sufficient to explain the prediction, they are not necessarily complete. Thus, they do not contain all evidence in the MRI pointing towards a diagnosis, a property which could prove essential for personalization. We have addressed this problem through pragmatic solutions, namely ensembling and targeted augmentations, but theoretical development of the core methodology might be necessary to guarantee complete maps. Beyond the fundamental aspects of LRP, there are weaknesses to the present study that should be acknowledged. First, the dementia dataset, although portrayed as heterogeneous, consists mostly of ADNI and OASIS data, and thus of patients with a probable AD diagnosis (although clinically determined). Thus, while we consider it likely, it is not necessarily true that the dimension of variability spanning from healthy controls to dementia patients portrayed by our model has the expressive power to extrapolate to other aetiologies. To overcome this in actual clinical implementations, we encourage the use of datasets that are organically collected from subsets of the population experiencing early cognitive impairments, for instance from memory clinics. Furthermore, it is not trivial to determine whether a clinical, broad dementia label is an ideal predictive target for models in clinical scenarios.
Both ADNI and AIBL contain rich biomarker information with multiple variables known to be associated with dementia, such as amyloid positivity. It would be intriguing to see studies methodologically similar to ours with a biological predictive target, and we encourage investigations into whether this supports and complements the results we have observed here. Another limitation of the present study is out-of-sample generalization, especially related to scanners and acquisition protocols. Although we utilized data from many sites, which we have earlier shown to somewhat address this problem61, in combination with transfer learning, we did not explicitly test this by, e.g., leaving sites out for validation. Again, we advise that clinical implementations should be based on realistic data, and thus at least be finetuned towards data coming from the relevant site, scanner, and protocol implemented in the clinic62. This also includes training models with class frequencies matching those observed in clinical settings, instead of naively balancing classes as we have done here. Next, we want to explicitly mention the cyclicality of our mask-and-predict validation. In a sense it trivially follows that regions considered important by a model are also the ones driving its predictions, and thus it is no surprise that the relevance maps coming from the dementia model are more important to the dementia model than the maps coming from, e.g., the sex model. We addressed this by alternating the models for test and validation, but fully avoiding this circularity would require disjoint datasets, and more and larger cohorts. Finally, we highlight the potential drawbacks of including the improving MCI patients alongside the stable ones in the progression models. We believe this accurately depicts a realistic clinical scenario, where diagnostic and prognostic procedures happen based on currently available clinical information.
However, that these patients improve could indicate that their condition is not caused by stable biological aberrations. This could oversimplify the subsequent predictive task, inflating our performance measures. In summary, the predictive value we observed for the individual patient must be interpreted with caution. However, our extensive validation approach as well as our thorough explanation of the method and its limitations, and training on large datasets, provide a first step towards making explainable AI relevant for clinical decision support in neurological disorders. Nonetheless, it also reveals a complicated balance between validating against existing knowledge and allowing for new discoveries. In our case, confirming whether small details revealed in the relevance maps are important aspects of individualization or simply intra-individual noise requires datasets with a label-resolution beyond what currently exists. Thus, we reiterate our belief that the continuation of our work should happen at the intersection between clinical practice and research63, by continuously collecting and labeling data to develop and validate technology in realistic settings.
To conclude, while there are still challenges to overcome, our study provides an empirical foundation and a roadmap for implementations of brain MRI-based explainable AI in personalized clinical decision support systems. Specifically, we show that deep neural networks trained on a heterogeneous set of brain MRI scans can predict dementia, and that their predictions can be made human interpretable. Furthermore, our pipeline allows us to reason about structural brain aberrations in individuals showing early signs of cognitive impairment by providing personalized characterizations which can subsequently be used for precise phenotyping and prognosis, thus fulfilling a realistic clinical purpose.
Methods
Data
The data used here was obtained from previously published, publicly accessible studies. All of these collected informed consents from their participants and received approval from their respective institutional review board or relevant research ethics committee. The present study was performed with approval from the Norwegian Regional Committees for Medical and Health Research Ethics (REK) and conducted in accordance with the Helsinki Declaration.
To train the dementia models we compiled a case-control dataset from seven different sources (Supplementary Table 1), consisting of patients with a dementia diagnosis and healthy controls from the same scanning sites. Because of the different diagnostic criteria used in the original datasets, we applied different rules to achieve a singular, heterogeneous dementia label (Supplementary Table 2). We extracted all participants with a dementia-diagnosis at all timepoints to comprise the patient group (n = 854). Then, for each unique proxy site (in ADNI, due to a large number of scanners and acquisition protocols, and the work put into unifying them, we used field strength as a proxy for site), sex, and age-bin spanning 10 years, we sampled an equal number of healthy controls to form the matched control set (total n = 1708, Table 1). Lastly, before modeling, we split the data into five equally sized folds stratified on diagnosis, site, sex, and age, such that all timepoints for a single participant resided in the same fold.
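The stratum-matching described above can be sketched as follows. This is an illustrative sketch, not the study's actual code: the record fields (`site`, `sex`, `age`), the dictionary representation, and sampling without replacement are our own assumptions.

```python
import random
from collections import defaultdict

def sample_matched_controls(patients, controls, seed=42):
    """For every patient, draw one healthy control from the same
    (site, sex, 10-year age bin) stratum, without replacement.
    Field names are hypothetical placeholders."""
    rng = random.Random(seed)
    pools = defaultdict(list)
    for c in controls:
        pools[(c["site"], c["sex"], c["age"] // 10)].append(c)
    matched = []
    for p in patients:
        pool = pools[(p["site"], p["sex"], p["age"] // 10)]
        matched.append(pool.pop(rng.randrange(len(pool))))
    return matched

patients = [{"site": "A", "sex": "F", "age": 72},
            {"site": "A", "sex": "M", "age": 65}]
controls = [{"site": "A", "sex": "F", "age": 75},
            {"site": "A", "sex": "M", "age": 61},
            {"site": "A", "sex": "F", "age": 70}]
matched = sample_matched_controls(patients, controls)
```

Sampling without replacement guarantees a 1:1 patient-to-control ratio within each stratum, which is why accuracy and balanced accuracy coincide later on.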
For the MCI dataset, we started with all participants from all ADNI waves with an MCI diagnosis (subjective memory complaint, MMSE between 24 and 30, CDR > 0.5 with memory box > 0.5, Wechsler Memory Scale-Revised <9 for 16 years of education, <5 for 8–15 years of education and <3 for 0–7 years of education)64, on at least one timepoint. These were 12661 images from 6448 visits for 1256 participants, none of which were used for model training. This selection criterion ensured all participants had an MCI diagnosis at one point in time, though it did not limit us to only those timepoints. Thus, in addition to those with a consistent, stable MCI diagnosis (sMCI), we had a variety of diagnostic trajectories, including those transitioning from normal cognition to MCI, MCI to AD (pMCI) and various other combinations. Before the subsequent analyses we discarded all participants without an MCI diagnosis initially, and everyone with ambiguous trajectories (e.g., MCI → CN → AD), leaving 5607 visits from 1138 participants.
From these two datasets, we extracted T1-weighted structural MRI data for each participant at each timepoint to use as inputs for the subsequent predictive models. Prior to modeling, the raw images were minimally processed using a previously developed pipeline58 relying on FreeSurfer v5.3 and FSL v6.065 to perform skullstripping66 and linear registration to MNI152-space67 with six degrees of freedom. Consequently, the processed images consisted of normalized voxel values from the raw images, registered to a common spatial template, and contained minimal non-brain tissue.
Modeling
All dementia models were variants of the PAC2019-winning simple fully convolutional network architecture68,69, modified to have a single output neuron with a sigmoid activation. The architecture is a simple, VGG-like convolutional neural network with six convolutional blocks and ~3 million parameters. We initialized the model with weights from a publicly accessible brain age model previously shown to have superior generalization capabilities when dealing with unseen scanning sites and protocols61. The models were trained on a single Nvidia A100 GPU with 40 GB of memory, using Tensorflow 2.670 through the Keras interface71. We used a vanilla stochastic gradient descent (SGD) optimizer with a learning rate defined by the hyperparameter settings (see next section), optimizing the binary cross-entropy loss. All models ran for 160 epochs with a batch size of 6, and for each run the epoch with the lowest validation loss was chosen. Varying slightly depending on the hyperparameters, a single model trained in ~4 h.
For each hold-out test fold we trained models on three of the remaining folds and validated on the fourth, akin to a cross-validation with an additional out-of-sample test set, to achieve out-of-sample predictions for all 1708 participants while allowing for hyperparameter tuning. The hyperparameters we optimized were dropout \(d\,\in \left\{0.25,\,0.5\right\}\) and weight decay \(w\,\in \left\{{10}^{-2},\,{10}^{-3}\right\}\). Additionally, we tested stepwise, one-cycle and multi-cycle learning rate schedules and a light and a heavy augmenter. Initial values for the learning rate were set manually based on a learning rate sweep72, though kept conservative to preserve the learned features from the pretraining. The hyperparameter search was implemented as a naive grid-search over the total 24 different configurations (Supplementary Fig. 9). We selected the model procuring the best AUC in the validation set to produce out-of-sample predictions for the outer hold-out fold. In the final evaluation of the models, we compiled predictions for all participants, for each using the model where they belonged to the hold-out test set. Our main method for measuring performance was the AUC, but we also report accuracy, which, due to our matching procedure, is equivalent to balanced accuracy.
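As a sanity check on the size of the search space, the grid over the four tuned settings can be enumerated with `itertools.product`; the string labels for the schedules and augmenters are hypothetical placeholders, only the counts follow from the text.

```python
from itertools import product

dropouts = [0.25, 0.5]              # d ∈ {0.25, 0.5}
weight_decays = [1e-2, 1e-3]        # w ∈ {1e-2, 1e-3}
lr_schedules = ["stepwise", "one-cycle", "multi-cycle"]  # placeholder labels
augmenters = ["light", "heavy"]     # placeholder labels

# naive grid search: every combination is trained and validated once
grid = list(product(dropouts, weight_decays, lr_schedules, augmenters))
print(len(grid))  # 2 * 2 * 3 * 2 = 24 configurations
```

Each of the 24 configurations corresponds to one cell of the grid in Supplementary Fig. 9; the best configuration per validation fold (by AUC) is then applied to the hold-out test fold.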
Relevance maps
We built a pipeline \({LR}{P}_{{dementia}}\) for generating relevance maps by implementing LRP37 on top of the trained classifier. LRP is a technique for explaining single decisions made by the model, and thus, when running the pipeline on input \(X\) a relevance map \(R\) is generated alongside the prediction \(\hat{y}\). \(R\) is a three-dimensional volume, representing a visual explanation for \(\hat{y}\), where each voxel \({r}_{i,j,k}\,\in R\) has a spatial position \(i,j,k\) corresponding to the location of an input voxel \({x}_{i,j,k}\,\in X\). Furthermore, the intensity of \({r}_{i,j,k}\) can be directly interpreted as how much voxel \({x}_{i,j,k}\) contributes to \(\hat{y}\), such that \(\sum _{r\in R}r=\,\hat{y}\). In the original LRP-formulation, relevance \(r\) is propagated backwards between subsequent layers \({Z}_{l}\) and \({Z}_{l+1}\) with artificial neurons \({a}_{m}\in {Z}_{l}\) and \({a}_{n}\in {Z}_{l+1}\) such that \(r({a}_{m})\) is proportional to how much \({a}_{m}\) contributes to the activations of all \({a}_{n}\) in the forward pass (Eq. (1)).

\(r(a_m) = \sum_n \frac{a_m w_{mn}}{\sum_{m'} a_{m'} w_{m'n}}\, r(a_n)\)  (1)

where \({w}_{{mn}}\) denotes the weight between \({a}_{m}\) and \({a}_{n}\).
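This proportional redistribution can be sketched for a single linear layer in NumPy. This is a minimal illustration of the basic LRP rule, not the paper's implementation: the helper name `lrp_linear`, the `eps` stabilizer, and the restriction to positive activations and weights (to keep denominators away from zero) are our own assumptions.

```python
import numpy as np

def lrp_linear(a, W, r_out, eps=1e-12):
    """Redistribute output relevance r_out to the inputs of a linear layer
    z_n = sum_m a_m * w_mn, proportionally to each input's contribution."""
    z = a @ W              # pre-activations of the layer above
    s = r_out / (z + eps)  # relevance per unit of activation; eps avoids 0-division
    return a * (s @ W.T)   # r(a_m) = a_m * sum_n w_mn * s_n

rng = np.random.default_rng(0)
a = rng.uniform(0.1, 1.0, 4)        # strictly positive activations
W = rng.uniform(0.1, 1.0, (4, 3))   # strictly positive weights
r_out = rng.uniform(0.1, 1.0, 3)
r_in = lrp_linear(a, W, r_out)
```

A key property of this rule is conservation: the relevance arriving at the inputs sums (up to the stabilizer) to the relevance that left the outputs, which is what makes \(\sum _{r\in R}r=\hat{y}\) hold across the whole network.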
We controlled the influence of different aspects of the explanations using a composite LRP strategy43, combining different formulations of the LRP formula for the different layers in the model to enhance specific aspects of the relevance maps. Specifically, we employed a combination of alpha-beta and epsilon rules that have previously been shown to produce meaningful results for dementia classifiers41,42. For the prediction layer, we retained the most salient explanations through an \({LR}{P}_{\epsilon }\)-rule (Eq. (2)).

\(r(a_m) = \sum_n \frac{a_m w_{mn}}{\sum_{m'} a_{m'} w_{m'n} + \epsilon \cdot \mathrm{sign}\left(\sum_{m'} a_{m'} w_{m'n}\right)}\, r(a_n)\)  (2)
For the central convolutional layers, we upweighted positive relevance (e.g., features increasing the prediction, corresponding to evidence for a diagnosis) with \({LR}{P}_{\alpha \beta }\)-rules (Eq. (3)).

\(r(a_m) = \sum_n \left( \alpha \frac{(a_m w_{mn})^{+}}{\sum_{m'} (a_{m'} w_{m'n})^{+}} - \beta \frac{(a_m w_{mn})^{-}}{\sum_{m'} (a_{m'} w_{m'n})^{-}} \right) r(a_n)\)  (3)

where \({\left(\cdot \right)}^{+}\) and \({\left(\cdot \right)}^{-}\) denote positive and negative contributions, respectively.
For the input layer and the subsequent convolutional layer, we employed \({LR}{P}_{b}\) to smooth finer details of the relevance maps (Eq. (4)).

\(r(a_m) = \sum_n \frac{1}{|o|}\, r(a_n)\)  (4)

where \(|o|\) denotes the number of nodes connected to \({a}_{n}\).
The resulting relevance maps produced by the pipeline were full brain volumes with the same dimensionality as the MRI data (167 × 212 × 160 voxels) containing mostly (see below) positive relevance.
Notation-wise we generally consider the relevance map \(R\left(X\right)\) for an image \(X\) to be a function of the model \({m}_{{task}}\), where \({task}\) indicates which task the model was trained for, the LRP strategy \(\text{LR}{\text{P}}_{\text{composite}}\) and the image \(X\) (Eq. (5)).

\(R(X) = R(m_{task}, \mathrm{LRP}_{composite}, X)\)  (5)
Because the composite LRP strategy described above is kept fixed in our pipeline, this can be contracted (Eq. (6)).

\(R(X) = R(m_{task}, X)\)  (6)
Furthermore, the model-specifier \({task}\) can also annotate the map for a further simplification (Eq. (7)).

\(R_{task}(X) = R(m_{task}, X)\)  (7)
Thus, \({LR}{P}_{{task}}\) is used to annotate the full pipeline for a given task, while \({R}_{{task}}\left(X\right)\) denotes a single relevance map generated by this pipeline for image \(X\). When the task is given by the context, we sometimes simplify this further to \(R\left(X\right)\), and when a general image is considered, we simply use \(R\) to denote its relevance map.
While we generally discuss our pipeline as a singular one, there were in reality five approximately equivalent pipelines (corresponding to the models trained for the five test folds), and which one is used depends on what image was used as input. Specifically, for each participant diagnosed with dementia, the pipeline is chosen where the participant was part of the hold-out test set while training the model, and both the relevance maps and the predictions are thus always out-of-sample. For participants used in the MCI analysis, which are all out-of-sample for all models, we created an ensemble by averaging the predictions and the voxel-wise relevance across all models.
Before implementing the LRP procedure we made two slight modifications to the models to facilitate the backwards relevance propagation, both leaving the functional interface of the model unchanged. First, we removed the sigmoid activation in the final layer, so that the output of the model changed from a bounded continuous variable \(\hat{y}\,\in \left[0,\,1\right]\) to an unbounded prediction \({\hat{{\rm{y}}}}_{\sigma }\,\in \left[-\infty ,\,\infty \right]\). In this space a raw prediction of \({\hat{{\rm{y}}}}_{\sigma }\,=0\) is equivalent to a sigmoid-transformed prediction of \(\hat{y}=0.5\), and thus \({\hat{{\rm{y}}}}_{\sigma }\, < \,0\) means that the model predicts control status for the given participant, and oppositely \({\hat{{\rm{y}}}}_{\sigma }\, > \,0\) implies that the model predicts a dementia diagnosis. Furthermore, this means that all positive relevance \(r\,\in R,\,r\, > \,0\) can be interpreted as visual evidence in favor of a dementia diagnosis. Second, we modified the model by fusing all batch normalization layers with their preceding convolutional layers, adjusting their weights and biases to match the shift and scaling previously performed by the normalization layer (Eq. (8)).
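The batch-norm fusion can be sketched per channel as follows. This is a generic sketch of the standard folding identity, not the paper's code: the parameter names and the toy check (scalar per-channel weights in place of full convolution kernels) are our own assumptions.

```python
import numpy as np

def fuse_bn(w, b, gamma, beta, mu, var, eps=1e-3):
    """Fold a batch-norm layer (scale gamma, shift beta, running mean mu,
    running variance var) into the preceding layer's per-channel weights w
    and biases b, so that conv -> BN equals a single fused conv."""
    scale = gamma / np.sqrt(var + eps)
    return w * scale, (b - mu) * scale + beta

# toy check on one activation per channel
rng = np.random.default_rng(1)
x = rng.normal(size=5)
w, b = rng.normal(size=5), rng.normal(size=5)
gamma, beta = rng.normal(size=5), rng.normal(size=5)
mu, var = rng.normal(size=5), rng.uniform(0.5, 2.0, 5)

y_bn = gamma * ((x * w + b - mu) / np.sqrt(var + 1e-3)) + beta  # conv then BN
wf, bf = fuse_bn(w, b, gamma, beta, mu, var)
y_fused = x * wf + bf                                           # fused conv only
```

After fusion the network computes identical outputs while consisting only of layer types for which the LRP rules above are defined.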
All visualized relevance maps are plotted after non-linear registration, overlaid on the MNI152-template. As the maps are three-dimensional, we generally plot a collection of distributed axial slices. The relevance is colored by the nibabel v3.2.275 cold_hot colourmap. Since the absolute relevance values vary between maps, all maps are normalized to the intensity range [0, 1] in the visualizations.
Validating the relevance maps
Earlier studies have shown that interpretability techniques in general are prone to generate visual explanations that do not capture salient parts of the input50,51. To investigate the extent of this for our pipeline \({LR}{P}_{{dementia}}\) we employed two analyses to assess the sanity of the relevance maps. The first was an established task-specific technique comparing the relevance maps to existing knowledge of the pathology of dementia40. The second was a purely quantitative analysis examining how important the regions found by the pipeline are for the dementia prediction \(\hat{y}\). In both cases we contrasted the relevance maps generated from the main pipeline with three alternative pipelines representing variants of a null hypothesis, all expected to produce relevance maps with no significant association with dementia.
\({LR}{P}_{{random}\,{images}}\) represents the simplest alternative pipeline, and is built around the dementia model, but with an additional preprocessing step scrambling the input (Eq. (9)).

\(R_{random\,images}(X) = R_{dementia}(X'),\quad X' \sim \mathcal{X}\)  (9)

where \({\mathcal{X}}={\mathcal{N}}\left(\bar{X},\,{\sigma }_{X}\right)\)
\({LR}{P}_{{random}\,{images}}\) is expected to generate relevance maps where the relevance is evenly distributed across the entire image. In the next pipeline \({LR}{P}_{{random}\,{weights}}\) we replaced the dementia-model with a model with random weights (Eq. (10)).

\(R_{random\,weights}(X) = R(m_{\theta}, X)\)  (10)
\({m}_{\theta }\) has not been trained for any task, and thus has random weights initialized by the default Keras ”Glorot Uniform” weight-initializer. This pipeline is expected to produce relevance maps which correlate with the raw voxel intensities, i.e., high intensity in the input should entail more (absolute) relevance, thereby reflecting aspects of morphology. The final and most realistic alternative pipeline was \({LR}{P}_{{sex}}\), where we replaced the dementia-model with a binary sex-classifier (Eq. (11)).

\(R_{sex}(X) = R(m_{sex}, X)\)  (11)
The sex-classifier was trained to differentiate males from females in one of the splits from the dementia-dataset, achieving an out-of-sample AUC of 0.956 and a balanced accuracy of 89.40%. We did not do any hyperparameter optimization for this model but used the best configuration from the dementia cross-validation in the same fold. The heatmaps from this pipeline should reflect regions where there is inter-individual variation in morphology, which are predictive of sex but with minimal association with dementia.
As a proxy for existing knowledge in the literature, we performed an ALE meta-analysis using Sleuth v3.0.476 and GingerALE v3.0.244. We used Sleuth to search for relevant articles with the query
in the Voxel-based morphometry database, yielding 394 experiments from 124 articles. These experiments contained 3972 foci, 280 of which were outside the MNI152 mask, leaving 3692 to be loaded into GingerALE. Then the reference map \(G\), with voxels \({g}_{i,j,k}\), was generated by an ALE meta-analysis using the default parameters: Cluster-level FWE = 0.01, Threshold Permutations = 1000, P value = 0.001. The reference map is visualized in Supplementary Fig. 4.
We performed four pairwise comparisons to estimate the amount of overlap between each of the pipelines and \(G\). For each pipeline the comparison was performed by computing an average map \(\bar{R}\), binarizing both it and \(G\), and computing the Dice overlap between the two. The employed approach closely resembles the method of Wang et al.40, but with multiple thresholds of binarization also for \(G\), and allowed us to plot similarity as a function of the threshold. For each pipeline, we first computed an average relevance map \(\bar{R}\) across all true positives (e.g., dementia patients that were correctly predicted to have a diagnosis by the dementia-model, n = 697), by computing their voxel-wise average. Next, we binarized both the average map (Eq. (12)) and the reference map (Eq. (13)) by thresholding them at multiple percentiles \(p\in \left[0,\,100\right)\).

\(\bar{R}^{p} = \mathbb{1}\left[\bar{R} > P_{p}(\bar{R})\right]\)  (12)

\(G^{p} = \mathbb{1}\left[G > P_{p}(G)\right]\)  (13)

where \(P_{p}(\cdot)\) denotes the \(p\)-th percentile.
Then, for each percentile \(p\) we calculate the Sørensen-Dice coefficient \({SD}{C}_{p}\) between the two (Eq. (14)).

\(SDC_{p} = \frac{2\,\left|\bar{R}^{p} \cap G^{p}\right|}{\left|\bar{R}^{p}\right| + \left|G^{p}\right|}\)  (14)
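The percentile binarization and Dice comparison can be sketched as follows; the synthetic stand-in volumes and the 10-point threshold sweep are illustrative assumptions, not the study's data or code.

```python
import numpy as np

def binarize(vol, p):
    """Keep only voxels above the p-th percentile of the volume."""
    return vol > np.percentile(vol, p)

def dice(a, b):
    """Sørensen-Dice coefficient between two binary masks."""
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())

rng = np.random.default_rng(2)
rel_map = rng.uniform(size=(8, 8, 8))                   # stand-in average relevance map
ref_map = rel_map + rng.normal(0, 0.1, rel_map.shape)   # noisy stand-in reference map

# Dice overlap as a function of the binarization threshold
overlap = [dice(binarize(rel_map, p), binarize(ref_map, p)) for p in range(0, 100, 10)]
```

Sweeping the threshold rather than fixing it avoids committing to one arbitrary cutoff when comparing a continuous relevance map against a continuous meta-analytic map.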
Additionally, to have a single numerical basis for comparison, we computed the normalized cross-correlation45 between the (non-binarized) average maps \(\bar{R}\) and the reference map \({G}\) (Eq. (15)).
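Eq. (15) amounts to a Pearson-style correlation between the zero-meaned, flattened maps; a sketch (names are ours):

```python
import numpy as np

def normalized_cross_correlation(r_bar, g):
    """Zero-mean, unit-norm correlation between two non-binarized maps."""
    a = r_bar - r_bar.mean()
    b = g - g.mean()
    return float((a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(1)
m = rng.random((8, 8, 8))
print(normalized_cross_correlation(m, m))   # 1.0 (up to rounding)
print(normalized_cross_correlation(m, -m))  # -1.0 (up to rounding)
```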
To facilitate an intuitive understanding of what parts of the brain the dementia-model is focusing on, we also performed a similar, region-wise comparison. This was done by extracting a subset of voxels from the average relevance map \({\bar{R}}_{{dementia}}\) belonging to each region \(\rho\) (Eq. (16)) from the Harvard-Oxford cortical and subcortical atlases77.
We did the same for \(G\) and let the mean activation per region for both constitute a tuple (Eq. (17)) plotted in Fig. 2c.
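The region-wise aggregation of Eqs. (16)-(17) reduces to averaging relevance within each atlas parcel. A sketch with a toy integer-labeled atlas (the real analysis uses the Harvard-Oxford parcellations; names are ours):

```python
import numpy as np

def region_means(relevance_map, atlas_labels):
    """Mean relevance per atlas region (Eqs. (16)-(17)).

    atlas_labels: integer array with the same shape as the map, where 0
    is background and 1..K index regions (e.g., Harvard-Oxford parcels).
    """
    regions = np.unique(atlas_labels)
    return {int(r): float(relevance_map[atlas_labels == r].mean())
            for r in regions if r != 0}

# Toy atlas with two regions.
atlas = np.zeros((4, 4, 4), dtype=int)
atlas[:2] = 1
atlas[2:] = 2
rmap = np.ones((4, 4, 4))
rmap[2:] = 3.0
print(region_means(rmap, atlas))  # {1: 1.0, 2: 3.0}
```

Swapping `mean` for `sum` or `max` gives the alternative aggregations shown in Supplementary Fig. 10.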
However, as it is non-trivial to determine which aggregation method corresponds to the most understandable and intuitive interpretation, we also created plots for tuples of sums (Eq. (18)) and maximum values (Eq. (19)) per region in Supplementary Fig. 10.
To quantify the importance of the spatial locations captured by the various LRP pipelines for predicting dementia, we implemented a procedure for iteratively occluding parts of the image based on the relevance maps and observing how the prediction from the dementia-model changed78. Still using the true positives, for each pipeline \({LR}{P}_{{task}}\), for each MRI \({X}_{0}\), we generated a baseline dementia-prediction \({\hat{{\rm{y}}}}_{0}\) and relevance map \({R}_{{task}}\). Then we located the voxel with the highest amount of relevance in \({R}_{{task}}\) and replaced a 15 × 15 × 15 cube centered around the voxel with random uniform noise \({\mathcal{U}}\left(0,\,1\right)\), effectively concealing all information contained in this region. Next, we ran the modified image \({X}_{{task}}^{1}\) through the dementia-model to see how the prediction \({\hat{y}}_{{task}}^{1}\) changed as a function of the occlusion. Note that injecting a box of random noise into the image is not trivially equivalent to removing information; however, we specifically applied the same modification in the random box-augmentation during training and are thus hopeful that the model is invariant to the injection beyond the information removal. We iteratively applied this modify-and-predict procedure for 20 iterations, masking out the occluded regions in the relevance maps between iterations to minimize overlap of occlusion windows, producing a list of predictions \(\left[{\hat{{\rm{y}}}}_{0},\,{\hat{y}}_{{task}}^{1},{\hat{y}}_{{task}}^{2},\,\ldots ,{\hat{y}}_{{task}}^{19}\right]\) plotted for each pipeline in Fig. 2d (averaged across all true positives). The rate of decline in these traces indicates the importance of the regions found in the respective relevance maps. We quantified the differences between the pipelines \({LR}{P}_{{task}}\) by calculating the area over their perturbation curves78 (AOPCs, Eq. (20)).
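The modify-and-predict loop can be sketched as follows; `model_fn` stands in for the trained dementia classifier, and all names are ours:

```python
import numpy as np

def occlusion_curve(model_fn, x, relevance, n_iter=20, size=15, rng=None):
    """Iteratively occlude the most relevant region and re-predict.

    model_fn: any callable mapping a 3D image to a scalar prediction.
    Returns [y_hat_0, y_hat^1, ..., y_hat^{n_iter-1}].
    """
    if rng is None:
        rng = np.random.default_rng(0)
    x = x.astype(float)
    rel = relevance.astype(float)
    half = size // 2
    preds = [model_fn(x)]
    for _ in range(n_iter - 1):
        i, j, k = np.unravel_index(np.argmax(rel), rel.shape)
        sl = tuple(slice(max(c - half, 0), c + half + 1) for c in (i, j, k))
        x[sl] = rng.uniform(0, 1, size=x[sl].shape)  # conceal this region
        rel[sl] = -np.inf                            # don't pick it again
        preds.append(model_fn(x))
    return preds

def aopc(preds):
    """Area over the perturbation curve (Eq. (20)): mean drop from baseline."""
    p = np.asarray(preds, dtype=float)
    return float((p[0] - p).mean())

# Toy check: with a mean-intensity "model" and a constant-bright image,
# each occlusion replaces ones with ~0.5 noise, so predictions decline.
x = np.ones((30, 30, 30))
rel = np.random.default_rng(1).random(x.shape)
preds = occlusion_curve(lambda img: float(img.mean()), x, rel, n_iter=5)
print(aopc(preds) > 0)
```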
Exploratory analyses in the MCI cohort
In the exploratory MCI analyses, we used \({LR}{P}_{{dementia}}\) to generate predictions and relevance maps for participants from ADNI who were given an MCI diagnosis at inclusion. We first compiled the predictions and relevance maps (and the corresponding timestamps) for each participant at all timepoints into a single data structure we called a morphological record. We then used this data structure to differentiate three groups: stable MCI patients (sMCI), progressive MCI patients (pMCI), and those who saw improvement in their cognition throughout the data collection phase. The remaining participants, i.e., those who either passed through all three diagnostic stages or bounced between diagnoses, were excluded. Furthermore, we combined the stable and improving cohorts into a non-progressive group (nMCI) to facilitate binary group comparisons in the subsequent analyses.
For the first analysis, due to variability in the total number and frequency of visits between participants, we created a matched dataset based on the visit histories of the nMCI and pMCI cohorts, allowing us to compare predictions in the two groups with reference to a specific timepoint. We started with all the progressive patients \({p}_{p}\,\in {pMCI}\) who got a diagnosis at timepoint \({t}_{n+1}\), and, for each patient individually, compiled all previous visits \({t}_{m},\,m\,\le \,n\) into a vector \({h}_{p}\) representing the time of the visits. The entries \({d}_{{t}_{m}}\) of the vector were the number of days until the diagnosis was given, \({t}_{n+1}-{t}_{m}\), including \({d}_{{t}_{n+1}}=0\) (Eq. (21)).
Then, for each of the non-progressive patients \({p}_{n}\,\in {nMCI}\) who did not have a time of diagnosis (i.e., \({t}_{n+1}\) is not given) we compiled a set \({H}_{p}\) of all possible history vectors \({h}_{p}\) by varying which visit was chosen as \({t}_{0}\) and a terminal non-diagnosis timepoint \({t}_{n+1}\). Next, we defined a cost-criterion for matching two histories (with an equal number of visits) as the sum of absolute pairwise differences between the vectors (Eq. (22)).
For each pair of progressive and non-progressive patients \(({p}_{p},{p}_{n})\) this allowed us to calculate a best possible match (Eq. (23)), given that the non-progressive patient had a total number of visits equal to or larger than the number of visits for the progressive patient.
Finally, we compiled the cost of the optimal match from all pairs into a matrix and found the best complete matching by minimizing the total cost across this matrix using the Hungarian algorithm implemented in scipy v1.6.3 (ref. 79), such that each patient occurs in at most one pair.
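A sketch of the matching criteria (Eqs. (22)-(23)) and the final assignment; scipy's `linear_sum_assignment` implements the Hungarian-style optimal matching used here (helper names are ours):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_cost(h_p, h_n):
    """Eq. (22): sum of absolute pairwise differences between two
    equal-length visit-history vectors (days before the terminal visit)."""
    return float(np.abs(np.asarray(h_p) - np.asarray(h_n)).sum())

def best_match(h_p, candidate_histories):
    """Eq. (23): cheapest candidate history with the same number of visits."""
    costs = [match_cost(h_p, h) for h in candidate_histories
             if len(h) == len(h_p)]
    return min(costs) if costs else np.inf

# Toy cost matrix: rows = progressive, columns = non-progressive patients.
cost = np.array([[2.0, 9.0],
                 [5.0, 3.0]])
rows, cols = linear_sum_assignment(cost)  # optimal complete matching
print(cost[rows, cols].sum())  # 5.0: pairs (0,0) and (1,1)
```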
We estimated differences in predictions \(\hat{y}\) between the two groups using a linear mixed model. Specifically, we modeled \(\hat{y}\) at all timepoints before the terminal timepoint \({t}_{n+1}\) as a function of age, sex (as controlling variables), years to diagnosis, categorical group membership (nMCI, pMCI), and an interaction between years to diagnosis and group. In addition, we included a random intercept and slope per participant. The model was fit through the formula API of statsmodels v0.13.2 (ref. 80) using the formula from Eq. (24) on the matched dataset.
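The exact terms are given in Eq. (24) of the source; our reconstruction from the description above, with fixed effects for age, sex, years to diagnosis, group, and the time-by-group interaction, plus a participant-level random intercept \(b_{0i}\) and slope \(b_{1i}\), reads:

```latex
\hat{y}_{it} = \beta_0 + \beta_{a}\,\mathrm{age}_{it} + \beta_{s}\,\mathrm{sex}_{i}
  + \beta_{y}\,\mathrm{years}_{it} + \beta_{g}\,\mathrm{group}_{i}
  + \beta_{yg}\,(\mathrm{years}_{it} \cdot \mathrm{group}_{i})
  + b_{0i} + b_{1i}\,\mathrm{years}_{it} + \varepsilon_{it}
```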
A full overview of coefficients and p values can be found in Supplementary Table 4.
Due to the high dimensionality of the relevance maps, we decomposed them with a principal component analysis (PCA) before the final analyses. To fit the PCA we used the non-linearly registered relevance maps from a randomly selected timepoint for all MCI patients. Before fitting the model, all relevance maps were smoothed with a constant 3 × 3 × 3 blurring kernel using the convolution operation from TensorFlow v2.6 to strengthen the signal-to-noise ratio. The PCA was computed using scikit-learn v1.0.2 (ref. 81), retaining 64 components (out of 1137 maximally possible) in a component vector \(c\,=\,\left[{c}_{0},\,{c}_{1},\,\ldots ,\,{c}_{63}\right]\). An axial slice from each of the 64 components visualized in MNI152-space is shown in Supplementary Fig. 6.
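A sketch of the decomposition step with toy dimensions; a `uniform_filter` of size 3 is equivalent, up to a constant factor and boundary handling, to convolving with a constant 3 × 3 × 3 kernel, so we use it in place of the TensorFlow convolution:

```python
import numpy as np
from scipy.ndimage import uniform_filter
from sklearn.decomposition import PCA

# Hypothetical shapes: 100 toy maps of 10x10x10 voxels each; the paper
# retains 64 of the up-to-1137 possible components.
rng = np.random.default_rng(0)
maps = rng.random((100, 10, 10, 10))

# Smooth each map with a constant 3x3x3 kernel.
smoothed = np.stack([uniform_filter(m, size=3) for m in maps])

# Flatten to (n_samples, n_voxels) and keep 64 components.
X = smoothed.reshape(len(smoothed), -1)
pca = PCA(n_components=64).fit(X)
c = pca.transform(X)  # component vectors [c_0, ..., c_63] per map
print(c.shape)  # (100, 64)
```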
We fit Cox proportional hazard models using the component vectors as predictors to assess the association between the relevance maps and progression as a function of age. In addition to the components, representing the maps, we controlled for sex in the model. The p values and coefficients can be found in Supplementary Table 5. To account for covariance between the components and the dementia-prediction \(\hat{y}\) we ran an additional model where we divided the patients into ten strata based on \(\hat{y}\). Both models were fit using lifelines v0.27.1 (ref. 82).
To further explore the prognostic efficacy of our pipeline we set up a predictive analysis for predicting progression at multiple fixed horizons a given number of years into the future. For each participant \(p\) with visits at timepoints \({t}^{p}\), we denoted the last timepoint with an MCI diagnosis \({t}_{{neg}}^{p}\) and the first timepoint with a dementia diagnosis (if present) \({t}_{{pos}}^{p}\). Using a fixed set of years into the future, \(\gamma \,\in \left\{1,\,2,\,3,\,4,\,5\right\}\), we constructed a target variable \({z}_{\gamma }\left({t}^{p}\right)\) encoding progression according to Eq. (25).
where the NAs allow for exclusion of all patients where the status at timepoint \({t}^{p}+\gamma\) is unknown. For each \(\gamma\) we constructed the target vector \({z}_{\gamma \,}\) across all timepoints for all participants with \({z}_{\gamma \,}\ne {NA}\) and split the constituent patients \(p\) into five folds stratified on \({z}_{\gamma }\), sex and age, such that all timepoints from a participant resided in the same fold. Using these folds, we fit logistic regression models to predict \({z}_{\gamma }\) with an \({l}_{1}\)-penalty in a nested cross-validation loop, allowing us to both tune the regularization parameter \(\lambda\) and have out-of-sample predictions for all participants. For eligible participants we used all timepoints for training the models, but during testing we sampled a random timepoint per participant to ensure independence between datapoints in the final evaluation. For each \(\gamma\) we fit three models: a baseline model to assess the bias in the dataset with respect to age at the given timepoint \({t}^{p}\) and sex (Eq. (26)), a model including the prediction \({\hat{y}}_{{t}^{p}}\) from the dementia classifier at \({t}^{p}\) as a predictor (Eq. (27)), and a model including the relevance maps from \({t}^{p}\), represented by the component vector \({c}_{{t}^{p}}\), as additional predictors (Eq. (28)).
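Our reading of the target construction in Eq. (25), sketched with time measured in years (the helper name is ours, and `None` plays the role of NA):

```python
def progression_target(t, t_neg, t_pos, gamma_years):
    """Label a visit at time t as progressed (1), stable (0), or
    unknown (None) gamma_years into the future.

    t_neg: last visit with an MCI diagnosis; t_pos: first visit with a
    dementia diagnosis, or None if the patient never progressed.
    """
    horizon = t + gamma_years
    if t_pos is not None and horizon >= t_pos:
        return 1      # known to have progressed by the horizon
    if horizon <= t_neg:
        return 0      # still MCI at the horizon
    return None       # status at the horizon is unknown -> excluded

# A patient last seen as MCI at year 3 who converted at year 4:
print(progression_target(0, t_neg=3, t_pos=4, gamma_years=2))  # 0
print(progression_target(0, t_neg=3, t_pos=4, gamma_years=5))  # 1
```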
All models were fit and tuned using the LogisticRegressionCV interface of scikit-learn v1.0.2 (ref. 81). We compared models by measuring the mean AUC across the five folds (Supplementary Table 6). To evaluate clinical applicability we also report accuracy, positive predictive value, sensitivity, and specificity (Table 2). To determine whether the more complex models represented significant improvements we employed a one-sided Wilcoxon signed-rank test from scipy v1.9.3 (ref. 79) for pairwise comparisons between \({{\mathcal{M}}}_{{base}}\) and \({{\mathcal{M}}}_{{pred}}\), and between \({{\mathcal{M}}}_{{pred}}\) and \({{\mathcal{M}}}_{{comp}}\), across the five out-of-sample AUCs.
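A sketch of the evaluation on synthetic data: an \(l_1\)-penalized logistic regression with the regularization strength tuned internally (inner loop) and AUC measured on five outer folds, followed by a one-sided Wilcoxon signed-rank comparison of two models' fold-wise AUCs. The data and the baseline AUCs are synthetic stand-ins:

```python
import numpy as np
from scipy.stats import wilcoxon
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
z = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

# Inner loop: LogisticRegressionCV tunes the l1 penalty over Cs values.
model = LogisticRegressionCV(penalty="l1", solver="liblinear", Cs=10, cv=5)
# Outer loop: five out-of-sample AUCs.
outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
aucs = cross_val_score(model, X, z, cv=outer, scoring="roc_auc")

# Pairwise, one-sided comparison against a (synthetic) weaker baseline.
aucs_base = aucs - rng.uniform(0.01, 0.05, size=5)
stat, p = wilcoxon(aucs, aucs_base, alternative="greater")
print(p < 0.05)
```

With only five folds the smallest attainable one-sided p value is 1/32, which bounds how strong the evidence from this test can be.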
To assess whether the relevance maps were associated with specific cognitive functions we related aspects of them to performance on various cognitive tests. We first extracted test results from seven neuropsychological batteries which spanned all ADNI waves and contained high-level summary scores from the ADNI website (Supplementary Table 7). We then manually extracted 17 summary scores spanning different, but overlapping, cognitive domains (Supplementary Table 8). The component vectors \(c\) were used as proxies for the relevance maps, where each represented a template for localization of pathology. We matched 2402 component vectors with test results from 733 MCI patients, forming the basis for the comparison. We then calculated the univariate association between cognitive performance on each of the 17 summary scores and each of the dimensions \({c}_{i}\,\in c\), while including age and sex as covariates for correction. To isolate the effect of the localization we also corrected for the dementia-prediction, \(\hat{y}\). When a patient had multiple potential matches, a random timepoint was selected, and the final number of datapoints used in the analyses varied from 518 to 675. Correction for multiple testing was done with the Benjamini-Hochberg procedure. To ensure the associations were not confounded by collinearities between \(c\) and \(\hat{y}\), we also performed an equivalent analysis without correcting for \(\hat{y}\) to observe whether the sign of the coefficients changed.
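The covariate-adjusted associations can be sketched as partial correlations (our implementation choice; the paper does not specify the exact estimator), with a hand-rolled Benjamini-Hochberg step-up adjustment:

```python
import numpy as np
from scipy import stats

def bh_adjust(pvals):
    """Benjamini-Hochberg adjusted p values (step-up procedure)."""
    p = np.asarray(pvals, dtype=float)
    order = np.argsort(p)
    n = len(p)
    adj = p[order] * n / np.arange(1, n + 1)
    adj = np.minimum.accumulate(adj[::-1])[::-1]  # enforce monotonicity
    out = np.empty(n)
    out[order] = np.minimum(adj, 1.0)
    return out

def adjusted_association(score, c_i, covariates):
    """Correlation between a cognitive score and one component c_i after
    regressing both on the covariates (e.g., age, sex, optionally y_hat)."""
    Z = np.column_stack([np.ones(len(score)), covariates])
    resid = lambda v: v - Z @ np.linalg.lstsq(Z, v, rcond=None)[0]
    return stats.pearsonr(resid(score), resid(c_i))

# Toy data: the score is driven by c_i, independently of the covariates.
rng = np.random.default_rng(0)
cov = rng.normal(size=(200, 2))           # age and sex stand-ins
c_i = rng.normal(size=200)
score = c_i + 0.5 * rng.normal(size=200)
r, p = adjusted_association(score, c_i, cov)
print(r > 0.5, p < 0.001)
```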
Reporting summary
Further information on research design is available in the Nature Research Reporting Summary linked to this article.