Introduction

Since its invention in the 1970s, magnetic resonance imaging (MRI) has provided an opportunity to non-invasively examine the inside of the body. In neuroscience, images acquired with MRI scanners have been used to identify how the brains of patients with various neurological disorders differ from those of their healthy counterparts. Classically, this has been done by collecting data from a group of patients with a given disorder and a comparable group of healthy controls, to which traditional statistical inference is applied to identify spatial locations of the brain where the groups differ1. Typically, these locations are not atomic locations identified by spatial coordinates, but rather morphological regions defined by an atlas, derived from empirical or theoretical insights into how the brain is structured. Differences between groups are described using morphometric properties, such as the thickness or volume of these prespecified regions. A major benefit of this approach is the innate interpretability of the results: on average, patients with a given disorder deviate in a specific region of the brain in a comprehensible manner. Furthermore, the high degree of localization offered by modern brain scans allows for accurate characterization of where and how the brain of an individual deviates from an expected, typically healthy, norm2. However, the effects that are found are typically small3, with limited predictive power at the individual level4,5, which in turn has raised questions about whether these analytical methods are expressive enough to model complex mental or clinical phenomena6. As an alternative, new conceptual approaches have been proposed, advocating modeling frameworks with increased expressive power that allow for group differences through complex, non-linear interactions between multiple, potentially distant, parts of the brain7, with a focus on prediction8.
Such modeling flexibility is naturally achieved with artificial neural networks (ANNs), a class of statistical learning methods that combines aspects of data at multiple levels of abstraction to accurately solve a predictive task9. However, while this often yields high predictive performance, e.g., by demonstrating clinically sufficient case-control classification accuracy for certain conditions, it comes at the cost of interpretation, as the models employ decision rules not trivially understandable by humans10. When the goal of the analysis is clinical, supporting the diagnosis and treatment of someone affected by a potential disorder, this opaqueness presents a substantial limitation. Thus, development and empirical validation of new methods within clinical neuroimaging that combine predictive efficacy with individual-level interpretability are imperative, to facilitate trust in how the system is working, and to accurately describe inter-individual heterogeneity.

With more than 55 million individuals afflicted worldwide11, over 25 million disability-adjusted life years lost12,13 and a cost exceeding one trillion USD yearly14, dementia is a prime example of a neurological disorder that incurs a monumental global burden. Due to the global aging population, the prevalence is expected to nearly triple by 205015, inciting a demand for technological solutions to facilitate handling the upcoming surge of patients. Dementia is a complex and progressive clinical condition16 with multiple causal determinants and moderators. Alzheimer’s disease (AD) is the most common form and accounts for 60–80% of all cases11. However, the brain pathologies underlying different subtypes of dementia are not disjoint, but often co-occur17,18,19, and have neuropathological commonalities20. The most prominent is neurodegeneration, occurring both in specific regions like the hippocampus and globally across the brain21, and inter-individual variations in the localization of atrophy have been associated with impairments in specific cognitive domains22,23. Thus, the biological manifestation of dementia in the brain is heterogeneous24, resulting in distinctive cognitive and functional deficits20, highlighting the need for precise and personalized approaches to diagnosis. For patients with mild cognitive impairment (MCI), a potential clinical precursor to dementia, providing individualized characterizations of the underlying etiological disease at an early stage could widen the window for early interventions25, alleviate uncertainty about the condition, and help with planning for the future26.

In dementia, ANNs, and particularly convolutional neural networks (CNNs), have been applied to brain MRIs to differentiate patients from controls27,28, prognosticate outcomes29, and differentially diagnose subtypes30. However, while research utilizing this technology has been influential, clinical translations are scarce31. Where techniques for segmenting brain tumors or detecting lesions typically produce segmentation masks that are innately interpretable, predicting a complex diagnosis would entail compressing all information contained in a high-dimensional brain scan into a single number. Using deep learning, the decisions underlying this immense reduction are obfuscated from the developer of the system, from the clinical personnel using it, and from the patient ultimately impacted by the decision. This black box nature is broadly credited for the low levels of adoption in safety-critical domains like medicine32. Responding to this limitation, explainable artificial intelligence (XAI) provides methodology to explain the behavior of ANNs33. The nature of these explanations varies, e.g., by what type of model is to be explained, what conceptual level the explanation is at, and who it is tailored for34,35. In computer vision, XAI typically aims for post-hoc explanations of individual decisions, explaining why a model arrived at a given prediction for a given image. Explanations are often provided in a visual format, as a heatmap indicating how different regions of the image contribute to the prediction.

Fig. 4: A visualization of the proposed morphological record for a randomly selected progressive MCI patient that was held out of all models and analyses.

a The top half shows the prediction from the dementia model at each visit, while the bottom part displays the relevance map underlying the prediction. The opaque sections (including c, d, and e) contain information accessible at the imagined current timepoint (22.02.07) to support a clinician in a diagnostic procedure. The angle (\(\angle\)) represents the change in dementia prediction per year based on the first two visits. b Translucent regions reveal the morphological record for the remaining follow-ups in the dataset, thus depicting the future. The ground truth diagnostic trajectory is encoded by the color of the markers. c Predicted probabilities of progression at future follow-ups based on the prediction and relevance map at the current timepoint. d Survival curve of the patient compared to the average MCI patient calculated from the prediction and relevance map. The marker indicates the location of the patient at the current timepoint. e A list of cognitive domains where the patient is predicted to significantly differ from the average based on the prediction and relevance map.

The regions with the highest density of relevance in our maps were the nucleus accumbens, amygdala and the parahippocampal gyrus, all of which are strongly affected in dementia52,53,54. While the two latter corroborate the established involvement of the medial temporal lobe55, it is surprising that the hippocampus does not appear in our analyses, as it frequently has in similar studies38,41,42. While this could be caused by the actual localization of pathology56, we consider it more likely to be related to the internal machinery of the model. Specifically, the CNN relies on spatial context to identify brain regions before assessing their integrity, utilizing filters that span areas of the image larger than those containing the region itself. In the backward pass, LRP uses these filters, and thus the localization of relevance is not necessarily voxel precise. Furthermore, we believe the model can broadly be seen as an atrophy detector, which necessarily entails looking for gaps surrounding regions instead of directly at the regions themselves. Therefore, while the relevance maps provide important information, they depend on contextual information and thus rely on interpretation from clinicians to maximize their utility in clinical practice.
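To make the backward pass concrete, the core redistribution step of LRP can be illustrated with the epsilon rule for a single fully connected layer. This is a didactic NumPy sketch under our own naming, not the implementation used in this study, and a CNN additionally requires analogous rules for convolutional and pooling layers:

```python
import numpy as np

def lrp_epsilon(weights, activations, relevance_out, eps=1e-6):
    """Epsilon-rule LRP for one fully connected layer.

    Relevance arriving at each output neuron is redistributed to the
    inputs in proportion to their contributions z_ij = a_i * w_ij,
    with a small stabilizer eps added to the denominator.
    """
    z = activations[:, None] * weights        # contributions z_ij, shape (in, out)
    denom = z.sum(axis=0)                     # total pre-activation of each output
    denom = denom + eps * np.where(denom >= 0, 1.0, -1.0)  # stabilize near zero
    return (z / denom) @ relevance_out        # relevance assigned to each input
```

Applying this rule layer by layer, from the output back to the voxels, yields a heatmap whose total relevance is approximately conserved, which is why relevance concentrates on the filters' receptive fields rather than on single voxels.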

We focused our analyses mainly on the relevance maps, but the results with the largest immediate potential for clinical utility were the predictions from the dementia classifier. Other studies have shown the efficacy of machine learning models in differentiating dementia patients and healthy controls28, but it is intriguing that we see a large discrepancy in the predictions of the progressive and non-progressive MCI patients many years before the dementia diagnosis is given. This corroborates findings from theory-driven studies57 and a recent deep learning study27, implying detectable structural brain changes many years before the clinical diagnosis is given. This gives hope for advanced technology to contribute to early detection and diagnosis through MRI-based risk scores, in our case supported by a visual explanation. If curative treatments prove efficacious and become accessible, early identification of eligible patients could be imperative58. Furthermore, timely access to interventions has shown efficacy in slowing the progression of cognitive decline59, in addition to improving the quality of life for those afflicted and their caregivers26,60. Widely accessible technology that allows for early detection with high precision could play a key role in the collective response to the impending surge of patients and provide an early window of opportunity for more effective treatments.

While our results show a great potential for explainable AI, and particularly LRP, as a translational technology to detect and characterize dementia, there are limitations to our study. First, there are technical caveats to be aware of. Most importantly, there is an absolute dependence between the predictions of our model and the relevance maps. In our case, when we qualitatively assessed the relevance maps of the false negatives, they were indistinguishable from the true negatives. This emphasizes the fact that when the model is wrong, this is not evident from the explanations. Next, while the maps contain information sufficient to explain the prediction, they are not necessarily complete. Thus, they may not contain all the evidence in the MRI pointing towards a diagnosis, a property which could prove essential for personalization. We have addressed this problem through pragmatic solutions, namely ensembling and targeted augmentations, but development of the core methodology might be necessary to theoretically guarantee complete maps. Beyond the fundamental aspects of LRP, there are weaknesses to the present study that should be acknowledged. First, while we portray our dementia patients as heterogeneous, the dataset mostly consists of ADNI and OASIS data, and thus of patients with a probable (although clinically determined) AD diagnosis. Thus, while we consider it likely, it is not necessarily true that the dimension of variability spanning from healthy controls to dementia patients portrayed by our model has the expressive power to extrapolate to other aetiologies. To overcome this in actual clinical implementations, we encourage the use of datasets that are organically collected from subsets of the population experiencing early cognitive impairments, for instance from memory clinics. Furthermore, it is not trivial to determine whether a broad clinical dementia label is an ideal predictive target for models in clinical scenarios.
Both ADNI and AIBL contain rich biomarker information with multiple variables known to be associated with dementia, such as amyloid positivity. It would be intriguing to see studies methodologically similar to ours with a biological predictive target, and we encourage investigations into whether this supports and complements the results we have observed here. Another limitation of the present study is out-of-sample generalization, especially related to scanners and acquisition protocols. Although we utilize data from many sites, which we have previously shown partially addresses this problem61, in combination with transfer learning, we did not explicitly test this by, for example, leaving sites out for validation. Again, we advise that clinical implementations should be based on realistic data, and thus at least be finetuned towards data coming from the relevant site, scanner, and protocol implemented in the clinic62. This also includes training models with class frequencies matching those observed in clinical settings, instead of naively balancing classes as we have done here. Next, we want to explicitly mention the cyclicality of our mask-and-predict validation. In a sense it trivially follows that regions that are considered important by a model are also the ones that are driving the predictions, and thus it is no surprise that the relevance maps coming from the dementia model are more important to the dementia model than the maps coming from, e.g., the sex model. We addressed this by alternating the models for test and validation, but fully avoiding this circularity would require disjoint datasets, and more and larger cohorts. Finally, we highlight the potential drawbacks of including the improving MCI patients alongside the stable ones in the progression models. We believe this accurately depicts a realistic clinical scenario, where diagnostic and prognostic procedures happen based on currently available clinical information.
However, the fact that these patients improve could indicate that their condition is not caused by stable biological aberrations. This could oversimplify the subsequent predictive task, inflating our performance measures. In summary, the predictive value we observed for the individual patient must be interpreted with caution. However, our extensive validation approach, our thorough explanation of the method and its limitations, and training on large datasets provide a first step towards making explainable AI relevant for clinical decision support in neurological disorders. Nonetheless, it also reveals a complicated balance between validating against existing knowledge and allowing for new discoveries. In our case, confirming whether small details revealed in the relevance maps are important aspects of individualization or simply intra-individual noise requires datasets with a label-resolution beyond what currently exists. Thus, we reiterate our belief that the continuation of our work should happen at the intersection between clinical practice and research63, by continuously collecting and labeling data to develop and validate technology in realistic settings.
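The occlusion step at the heart of the mask-and-predict validation can be sketched as follows. This is a simplified illustration under our own naming, not the exact procedure used in the study: the most relevant voxels according to a map are occluded, and a map that truly explains the prediction should degrade the model's output more than a map from an unrelated task (e.g., the sex model) when its top voxels are masked:

```python
import numpy as np

def mask_top_relevance(image, relevance, fraction, fill_value=0.0):
    """Occlude the `fraction` of voxels with the highest relevance.

    Returns a copy of `image` where the top-relevance voxels are
    replaced by `fill_value`; ties at the threshold are all masked.
    """
    k = max(1, int(fraction * relevance.size))           # number of voxels to mask
    threshold = np.partition(relevance.ravel(), -k)[-k]  # k-th largest relevance
    masked = image.copy()
    masked[relevance >= threshold] = fill_value          # occlude most relevant voxels
    return masked
```

In use, one would compare the model's prediction on `mask_top_relevance(scan, dementia_map, f)` against `mask_top_relevance(scan, sex_map, f)` across a range of fractions `f` (hypothetical variable names); the circularity discussed above arises because the same model family produces both the maps and the predictions being degraded.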

To conclude, while there are still challenges to overcome, our study provides an empirical foundation and a roadmap for implementations of brain MRI-based explainable AI in personalized clinical decision support systems. Specifically, we show that deep neural networks trained on a heterogeneous set of brain MRI scans can predict dementia, and that their predictions can be made human interpretable. Furthermore, our pipeline allows us to reason about structural brain aberrations in individuals showing early signs of cognitive impairment by providing personalized characterizations which can subsequently be used for precise phenotyping and prognosis, thus fulfilling a realistic clinical purpose.