Introduction

A fundamental challenge in medical studies is to accurately model confounding variables1,2,3. Confounders are extraneous variables that distort the apparent relationship between input (independent) and output (dependent) variables and hence lead to erroneous conclusions4,5,6 (see Fig. 1). For instance, when neuroimaging studies aim to distinguish healthy individuals (a.k.a. controls) from subjects impacted by a neurological disease, the input variables are images or image-derived features, and the output variables are the class labels (i.e., diagnosis). If the diseased cohort is, on average, significantly older than the healthy controls, the age of individuals potentially confounds the study7,8,9. When confounders are not properly modeled, a predictor may learn the spurious associations and influences created by the confounder (age, in this case) instead of the actual biomarkers of the disease10.

Fig. 1: Confounding effects in deep-learning models.

A confounder is a variable that influences both the input and the output of a study, causing a spurious association if not properly controlled for.

Traditionally, studies control for the impact of confounding variables by eliminating their influences on either the output or the input variables. With respect to the output variables, one can reduce the dependency on confounders by matching confounding variables across cohorts (during data collection)7 or through analytical approaches, such as standardization and stratification11,6. Associations between confounders and input variables are frequently removed by regression analysis12,6, which produces residualized variables that are regarded as the confounder-free input to the prediction models.
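As a minimal sketch of the regression-based residualization mentioned above, the following Python snippet removes the linear effect of a single confounder from a single precomputed feature by ordinary least squares. The function name `residualize` and the toy ages and feature values are hypothetical and serve only as illustration; the cited studies may rely on multivariate or nonlinear models.

```python
def residualize(feature, confounder):
    """Remove the linear effect of one confounder from one feature via OLS."""
    n = len(feature)
    mean_c = sum(confounder) / n
    mean_f = sum(feature) / n
    var_c = sum((c - mean_c) ** 2 for c in confounder)
    cov_fc = sum((c - mean_c) * (f - mean_f) for c, f in zip(confounder, feature))
    slope = cov_fc / var_c
    # Residual = centered feature minus the part linearly explained by the confounder.
    return [(f - mean_f) - slope * (c - mean_c) for c, f in zip(confounder, feature)]


# Hypothetical toy data: age in years and one image-derived measurement.
ages = [25.0, 40.0, 55.0, 70.0]
volumes = [10.0, 9.1, 8.3, 7.2]
residuals = residualize(volumes, ages)
```

By construction, the residuals have zero mean and zero sample covariance with the confounder, which is why they are treated as confounder-free inputs. The approach presupposes features that are fixed before model training.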

The most advanced image-based prediction models are based on convolutional neural networks (ConvNets)13,1,14,15,16,3. A standard ConvNet contains a feature extractor (\({\mathbb{FE}}\)) followed by a classifier/predictor (\({\mathbb{P}}\)). \({\mathbb{FE}}\) reduces each medical image to a feature vector F, based on which the fully connected layers of \({\mathbb{P}}\) predict a binary or continuous outcome y (Fig. 2a). Unlike traditional machine-learning models, ConvNets require large training data sets and adopt an end-to-end learning strategy to extract the features F on-the-fly from the raw image X. This renders the above methods of accounting for confounders unsuitable, as they either reduce the number of training samples (e.g., matching or stratification) or require deterministic features that are computed beforehand (e.g., standardization or regression). Possible alternatives could be unbiased17,18,19,20,21 and invariant feature-learning approaches22,23,24,3). We visually confirmed this finding by projecting the high-dimensional F of each control subject into 2D via t-SNE32. Figure 3d shows each subject as a point, whose color was defined by their age. While older subjects are concentrated in the upper-left region of the feature space associated with ConvNet, a clear pattern with respect to age was not visible for the projections associated with CF-Net (Fig. 3e).

To gain more insight into which anatomical regions drove the predictions, Fig. 3f, g visualizes the saliency maps33 of ConvNet and CF-Net, with yellow highlighting areas that the predictions heavily relied upon. Figure 3f reveals that ConvNet extracted features close to the ventricles and cerebellum, which are crucial markers of brain aging34 and were omitted by CF-Net. On the other hand, CF-Net produced higher saliency in the precentral and postcentral gyri, which are frequently linked to alterations in cortical structure and function in HIV-infected patients35,36. Other regions with high average saliency according to CF-Net are located in the temporal lobe, inferior frontal gyrus, and subcortical regions, including the amygdala and hippocampus. These regions (except for the amygdala) also exhibited significant white-matter tissue loss due to HIV according to a traditional voxel-based morphometry analysis37 (Supplementary Fig. 4).

Brain morphological sex differences in adolescent brains of the NCANDA study

The public dataset (Release: NCANDA_PUBLIC_BASE_STRUCTURAL_IMAGE_V0138) consisted of the baseline T1-weighted MRI of 334 boys and 340 age-matched girls (age 12–21 years, p > 0.5, two-sample t-test) from the National Consortium on Alcohol and NeuroDevelopment in Adolescence (NCANDA)39 that met the no-to-low alcohol drinking criteria of the study. The confounder of the study was the pubertal development score (PDS, Fig. 4a)39, which was significantly higher (p < 0.001, two-sample t-test) in girls (3.41 ± 0.6) than boys (2.86 ± 0.7).

Fig. 4: Sex prediction from adolescent brain MRIs.

a Significantly different pubertal development scores (PDS) between n = 334 boys and n = 340 girls (p < 0.0001 two-tailed two-sample t-test). Boxplots are characterized by minimum, first quartile, median, third quartile, and maximum. b, c Sex-prediction scores measured on all subjects and the c-independent subset containing n = 200 boys and n = 200 girls. d, e t-SNE visualization of the feature space learned by the deep-learning models. f, g Saliency maps of sex differences.

With respect to the ConvNet baseline, the results from the previous experiment were largely replicated. Based on 5-fold cross-validation, the accuracy in predicting sex dropped from 90.3% across all samples to 87.3% (Table 2) on a c-independent subset, which consisted of 200 boys and 200 girls with the same PDS distribution (3.14 ± 0.65). Being significantly confounded by PDS, the ConvNet produced a lower balanced accuracy (BAcc: 79.5%) for subjects in the early pubertal stage compared with an accuracy score of 90.6% for subjects in later stages (subcohorts divided by the mean PDS of 3.2). As boys had significantly lower PDS, the ConvNet tended to label girls with small PDS as boys (recall: 68.1%, Fig. 4b). Although the t-SNE projection of the ConvNet features showed a less pronounced correlation with PDS than in the HIV experiment (Fig. 4d), the confounding effect of PDS still significantly impacted the derived features, as revealed by the post hoc training of \({\mathbb{CP}}\) (Pearson’s r = 0.84, p < 0.001, Supplementary Fig. 7). Last, sex prediction of ConvNet was mostly based on the inferior parietal lobe, supramarginal region, cerebellum, and subcortical regions according to the saliency map of Fig. 4f (Supplementary Fig. 9).

Table 2 BAcc (precision and recall) on predicting sex from MRIs of NCANDA matched with respect to PDS. Optimal results were achieved when conditioning CF-Net on boys.

For CF-Net, the accuracy depended on the set of subjects used for training the component \({\mathbb{CP}}\), which, unlike in the HIV experiment, was not uniquely defined, as the modeling of the PDS effect could be conditioned on y = 0 (boys) or y = 1 (girls). According to Table 2, conditioning the training of \({\mathbb{CP}}\) on boys resulted in more accurate predictions in the c-independent subset and recorded a smaller gap in accuracy across subjects at different pubertal stages, while conditioning on girls not only reduced the BAcc, but also enlarged the discrepancy in precision and recall rates. As expected, similarly degraded performance was also observed when training \({\mathbb{CP}}\) on subjects of both sexes without conditioning (CF-Net (All) in Table 2). Among the three implementations of CF-Net, only the CF-Net conditioned on boys was significantly more accurate in prediction at the early pubertal stage (two-tailed p = 0.039, DeLong’s test) and produced features significantly less predictive of PDS (p < 0.001, one-sample t333 = 12.2, Supplementary Figs. 7 and 8) compared to ConvNet (Fig. 4b–e). Interestingly, the saliency map associated with this CF-Net implementation (Fig. 4f, g, Supplementary Fig. 9) focused only on subcortical regions.

Bone-age prediction from hand X-ray images

The dataset consisted of hand X-ray images of 12,611 children (6833 boys and 5778 girls) that were released by the Radiological Society of North America (RSNA) Radiology Informatics Committee (RIC) as a machine-learning challenge for predicting pediatric bone age40. The confounder in this study was sex as boys were significantly older than girls (boys: 134.8 ± 42.2 months, girls: 118.7 ± 38.2 months). We randomly chose 75% of the images (N = 9458) as training data and the remaining as validation data (N = 3153). The ConvNet was based on the publicly released implementation by the Kaggle challenge41. The feature extractor consisted of a pretrained VGG-16 backbone followed by an attention module41. This ConvNet achieved a mean absolute error of 13.8 months in predicting age from the X-rays of the validation set. The model tended to overestimate the age of girls compared to boys (Fig. 5b), and this discrepancy was more pronounced in the age range of 110–200 months (Fig. 5c).

Fig. 5: Bone-age prediction from hand X-ray images.

a Difference in the age distribution between n = 6833 boys and n = 5778 girls of the RSNA bone-age dataset (p < 0.0001, two-tailed two-sample t-test). b Ground truth vs. predicted age of the ConvNet. ConvNet tended to predict higher age for girls than boys, indicating a confounding effect of sex. c This prediction gap between boys and girls was more pronounced in the age range of 110–200 months, but was significantly reduced by CF-Net, which modeled the dependency between F and c on a y-conditioned cohort. d Absolute prediction error (in months) of n = 3153 testing subjects produced by ConvNet and CF-Net with (or without) conditioning. Boxplots are characterized by minimum, first quartile, median, third quartile, and maximum. CF-Net with conditioning resulted in the most accurate prediction (p < 0.0001, two-tailed two-sample t-test).

Next, we aimed to remove sex-related confounding effects in the attention module by CF-Net. Since the ConvNet was based on a VGG-16 feature extractor pretrained on the large number of natural images provided by ImageNet, the extractor was unlikely to contain confounding information for X-ray images. Hence, we only applied the \({\mathbb{CP}}\) component to adjust parameters of the attention module, but kept the VGG-16 feature extractor fixed. However, y was now a continuous variable as opposed to a binary one used in the previous experiments, so the y-conditioned cohort could not be defined with respect to a fixed prediction outcome. Instead, we applied the \({\mathbb{CP}}\) component to a bootstrapped training set of 10,000 boys and 10,000 girls whose age was confined to the interval from 75 to 175 months and had strictly matched distributions between the two genders (see “Methods” section). By doing so, CF-Net successfully reduced the sex-related gap in age prediction (Fig. 5c, Supplementary Figs. 11–13). Moreover, the prediction accuracy of CF-Net with y-conditioning was significantly higher (absolute error 11.2 ± 8.7 months) than that of the baseline ConvNet and CF-Net without y-conditioning (two-tailed p < 0.0001, one-sample t3152 = 14.2, Fig. 5d). The saliency maps of CF-Net were more localized on anatomical structures than those of ConvNet, indicating that the widespread pattern leveraged by ConvNet might be redundant and relate to confounder-related cues. Note, as in the prior experiment, the accuracy of CF-Net was similar to ConvNet when training \({\mathbb{CP}}\) on all subjects available (without conditioning on y).

Discussion

Accurate modeling of confounders is an essential aspect of analyzing medical images2. For example, traditional machine-learning models rely on precomputed features from which confounding effects are regressed out a priori7,10,6. This topic, however, is largely overlooked by deep-learning applications as researchers shift attention to designing deeper and more powerful network architectures to achieve higher classification/regression accuracy42,43,44. Indeed, end-to-end learning of deep models is often superior to traditional machine-learning methods relying on precomputed features. For example, the ConvNet baseline reported a higher accuracy (BAcc: 71.6%) in the HIV experiment than applying a traditional SVM classifier to the 298 brain regional measurements extracted by FreeSurfer45 (BAcc: 69.5%). The more accurate predictions of such deep models are in part due to increased sensitivity to subtle group differences, which also heightens the risk of biasing findings as these subtle differences may relate to confounders. For example, on the NCANDA dataset, ConvNet produced the highest prediction accuracy on the entire cohort, which was partially attributed to the confounding effect of PDS. Therefore, the superiority of a prediction model for medical imaging applications should be defined with respect to its predictive power and impartiality to confounders. However, the a priori strategies (used by traditional machine learning) for training impartial predictors do not work for end-to-end learning models as learning is based on extracting features on-the-fly from raw images. While recent advances in adversarial learning have shed light on this problem, existing deep models were only designed to tackle specific confounding effects such as scanner difference or dataset harmonization46,47,48.
Here, we propose a deep-learning architecture for systematically and accurately modeling confounders in medical imaging applications based on adversarially training a confounding predictor \({\mathbb{CP}}\) (see Fig. 2). \({\mathbb{CP}}\) can be used to remove confounding effects from any layer of a generic deep model, such as the entire feature extractor in the MRI experiments or a submodule of the extractor in the bone-age experiment.

By explicitly modeling the confounding effect in the feature-learning process, CF-Net bypasses the need of matching cohorts with respect to confounders, which generally reduces the sample size and thus negatively impacts generalizability of the model13. However, training models on confounded data now requires evaluating the fairness of model predictions with respect to confounders. In line with the concept of group fairness or demographic parity49,50, one can do so by examining whether the predictive power of the model varies across different validation subsets. We did so by measuring the difference between the testing accuracy recorded on the whole (confounded) cohort and on the c-independent (unconfounded) subset. We viewed this difference as a metric for the severity of the confounding effects: the larger the difference, the more confounded the model. Another way of defining validation subsets is to group testing subjects according to their confounder values (see Figs. 3b, c and 4b, c). In all three experiments, CF-Net achieved more balanced prediction accuracies across those subsets than ConvNet, further highlighting the fairness of the CF-Net model.
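The severity metric described above can be illustrated with a short Python sketch that computes the accuracy on the whole cohort minus the accuracy on the c-independent subset. Using balanced accuracy as the accuracy measure is an assumption made here, and the function names and toy labels are hypothetical.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of sensitivity and specificity for binary labels in {0, 1}."""
    positives = [p for t, p in zip(y_true, y_pred) if t == 1]
    negatives = [p for t, p in zip(y_true, y_pred) if t == 0]
    sensitivity = sum(p == 1 for p in positives) / len(positives)
    specificity = sum(p == 0 for p in negatives) / len(negatives)
    return 0.5 * (sensitivity + specificity)


def confounding_gap(y_true, y_pred, in_subset):
    """Accuracy on the whole cohort minus accuracy on the c-independent subset."""
    subset_true = [t for t, keep in zip(y_true, in_subset) if keep]
    subset_pred = [p for p, keep in zip(y_pred, in_subset) if keep]
    return balanced_accuracy(y_true, y_pred) - balanced_accuracy(subset_true, subset_pred)


# Toy labels and predictions; the last two subjects lie outside the matched subset.
y_true = [1, 1, 1, 0, 0, 0, 1, 0]
y_pred = [1, 1, 0, 0, 0, 1, 1, 0]
in_subset = [True, True, True, True, True, True, False, False]
gap = confounding_gap(y_true, y_pred, in_subset)
```

A gap close to zero suggests that the model performs comparably with and without the confounded samples, whereas a large positive gap suggests that part of the accuracy on the full cohort stems from the confounder.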

Another important property of CF-Net is its ability to model continuous confounders (e.g., age), whereas most existing fair machine-learning methods17,23,51,52,53,56 of the overall age distribution, which approximately encompassed 80% of the training subjects and focused only on the age range with sufficient samples (Supplementary Fig. 10). This well-represented age interval facilitated the decorrelation with respect to gender and resulted in a large y-conditioned cohort for training \({\mathbb{CP}}\). Another strategy for defining the interval (not explored in this paper) is to model the interval as a hyperparameter, whose optimal setting is determined via parameter exploration during nested cross-validation. Alternatively, one can bypass the need to select the interval by using data-driven matching procedures (e.g., a bipartite graph matching57 or greedy algorithm7), which in our experiments produced accuracy scores similar to those based on the FWHM criterion and bootstrapping.

Based on these different y-conditioning strategies, medical researchers can use CF-Net to train deep models on cohorts not strictly matched with respect to confounders without discarding unmatched samples. However, this does not mean that there is no need to keep the confounders in mind when recruiting participants for medical imaging studies. For all learning models, performing analysis on confounder-matched cohorts with sufficient samples remains a fundamental strategy to disentangle biomarkers of interest from the effects of confounders. For example, in the bone-age experiment, recruiting enough age–gender-matched samples resulted in a large y-conditioned cohort that reduces the risk of overfitting during the training of \({\mathbb{CP}}\). Conversely, if two cohorts have completely different distributions with respect to a confounder (e.g., one has participants with strictly larger age than the other), there is no guarantee that any method, including ours, can remove the bias in a purely data-driven fashion. Therefore, in the study-design stage, defining potential confounders for a specific medical application may require domain-specific knowledge to maximize the power of CF-Net in practice.

A limitation of our experiments was the focus on single confounders that were known a priori. To model unknown confounders, we aim to explore coupling CF-Net with causal discovery algorithms (such as refs. 58,59,60). In case predictions are biased by multiple confounders, we would need to extend \({\mathbb{CP}}\) to predict multiple outputs (one for each confounder) or add a separate \({\mathbb{CP}}\) component to CF-Net for each confounder. In the simple scenario that the confounding variables are conditionally independent with respect to y, each \({\mathbb{CP}}\) component can be trained on a separate y-conditioned cohort uniquely defined for each confounder. However, theoretical and practical ways of modeling high-order interactions between confounders require further investigation.

While we were able to visualize the HIV and sex effect by computing saliency maps61 inferred from the predictor \({\mathbb{P}}\), the same technique is not directly applicable to visualize confounding effects from \({\mathbb{CP}}\) due to the adversarial training. An alternative could be deriving saliency maps from \({\mathbb{CP}}\) retrained on the features learned by the baseline ConvNet (e.g., Supplementary Fig. 2), i.e., a model that substantially captures the confounding effect.

Finally, we abstained from determining the optimal implementation of the proposed confounder-free modeling strategy by performing extensive exploration of network architectures. Instead, we relied on some of the most fundamental network components used in deep learning. This rather basic implementation still recorded reasonable prediction accuracies, so the findings discussed here are likely to generalize to more advanced network architectures.

Methods

Materials

This study used multiple medical imaging data sets to evaluate different aspects of our proposed confounder-free neural network, described briefly herein. In addition, experiments on synthetic data sets are included in Supplementary Fig. 1, which shows the efficacy of our proposed framework in controlled settings.

HIV dataset: Our first task aimed at predicting the diagnosis of HIV patients vs. control subjects62. Participants ranged in age between 18 and 86 years and were all scanned with a T1-weighted MRI. All study participants provided written informed consent, and the study was approved by the Institutional Review Board (IRB) at Stanford University (Protocol ID: IRB-9861) and SRI International (Protocol ID: Pro00039132). HIV subjects were seropositive for the HIV infection with CD4 count \(> 100\,\frac{\,\text{cells}\,}{\rm{\mu L}}\) (average: 303.0). Construction of the c-independent subset was based on the matching algorithm7 that extracted the maximum number of subjects from each group in such a way that they were equal in size and identically distributed with respect to the confounder values. For each HIV subject, we selected a control subject with minimal age difference and repeated this procedure until all HIV subjects were matched or the two-tailed p value of the two-sample t-test between the two age distributions dropped to 0.5. The MR images were first preprocessed7 by denoising, bias-field correction, skull stripping, and affine registration to the SRI24 template63. The registered images were then downsampled to a 64 × 64 × 64 volume64 based on spline interpolation to reduce the potential overfitting during training and to enable a large batch size. Prediction accuracy of the deep models was determined via fivefold cross-validation. For each training run, MRIs were augmented to provide a sufficient number of samples for the model to be trained on. As in ref. 65, data augmentation produced new synthetic 3D images by randomly shifting each MRI within one voxel and rotating it within 1° along the three axes. The augmented dataset included a balanced set of 1024 MRIs for each group (control and HIV). Assuming that HIV affects the brain bilaterally7,66, the left hemisphere was flipped to create a second right hemisphere.
During testing, the right and flipped left hemispheres of the raw test images were given to the trained model, and the prediction score averaged across both hemispheres was used to predict the individual’s diagnosis group. Last, a saliency map was computed61 for the right hemisphere of each test image quantifying the importance of each voxel to the final prediction.
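The greedy pairing used above to construct the c-independent subset can be sketched in Python as follows. This is a simplified illustration under stated assumptions: patients are processed in the order given, each control is used at most once, and the t-test-based stopping rule is omitted. The function name `greedy_match` and the toy ages are hypothetical.

```python
def greedy_match(patient_ages, control_ages):
    """Pair each patient with the closest-aged unused control (greedy, no reuse)."""
    available = list(range(len(control_ages)))
    pairs = []
    for p_idx, p_age in enumerate(patient_ages):
        if not available:
            break  # no controls left to match
        # Pick the unused control with the smallest absolute age difference.
        best = min(available, key=lambda c_idx: abs(control_ages[c_idx] - p_age))
        available.remove(best)
        pairs.append((p_idx, best))
    return pairs


# Toy ages in years: three patients and five candidate controls.
pairs = greedy_match([30, 52, 67], [29, 45, 50, 66, 70])
```

Each tuple holds the index of a patient and the index of the matched control. In practice, a stopping criterion such as the p value threshold described in the text would end the loop once the two age distributions are no longer sufficiently similar.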

NCANDA dataset: Experiments were performed on the baseline T1 MR images of 334 boys and 340 girls from the NCANDA study (Public Release: NCANDA_PUBLIC_BASE_STRUCTURAL_V0167). Adult participants and the parents of minor participants provided written informed consent before participation in the study. Minor participants provided assent before participation. The IRB of each site approved the standardized data collection and use39. All subjects met the no-to-low alcohol drinking criteria of the study, and there was no significant age difference between boys and girls (p > 0.5, two-sample t-test). Pubertal stage was determined by the self-assessment pubertal development scale (PDS). Procedures for preprocessing, downsampling, and classifying the MRI were conducted according to the HIV experiment.

Bone-aging dataset: The RSNA Pediatric Bone Age Machine Learning Challenge was based on a dataset consisting of 14,236 hand radiographs (12,611 training sets, 1425 validation sets, and 200 test sets)40. We experimented on the 12,611 training images with ground-truth bone age (127.3 ± 41.2 months) and the ConvNet model publicly released on the Kaggle challenge page41. In total, 3914 boys and 3518 girls, or 80% of the training subjects (Fig. 5a), had bone ages between 75 months and 175 months (the FWHM of the age distribution, Supplementary Fig. 10). Confined to this age range, we used bootstrapping68 to generate 1000 boys and 1000 girls within each 10-month interval. This procedure resulted in a y-conditioned cohort of 10,000 boys and 10,000 girls strictly matched with respect to bone age (p = 0.19, two-tailed two-sample t-test).
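The stratified bootstrapping described above can be sketched in Python as follows. The half-open 10-month bins, the fixed random seed, the function name `bootstrap_matched`, and the synthetic ages are assumptions for illustration and do not reproduce the RSNA data.

```python
import random


def bootstrap_matched(ages_by_sex, low=75, high=175, bin_width=10, per_bin=1000, seed=0):
    """Resample, with replacement, the same number of subjects per sex in every age bin."""
    rng = random.Random(seed)
    cohort = {sex: [] for sex in ages_by_sex}
    for start in range(low, high, bin_width):
        for sex, ages in ages_by_sex.items():
            in_bin = [a for a in ages if start <= a < start + bin_width]
            if not in_bin:
                raise ValueError(f"no {sex} subjects in [{start}, {start + bin_width})")
            # Draw per_bin samples with replacement from this age bin.
            cohort[sex].extend(rng.choices(in_bin, k=per_bin))
    return cohort


# Synthetic bone ages in months (placeholders, not the RSNA data).
toy_ages = {
    "boys": [60 + 0.5 * i for i in range(260)],
    "girls": [55 + 0.7 * i for i in range(200)],
}
cohort = bootstrap_matched(toy_ages)
```

With ten bins of 10 months between 75 and 175 months and 1000 draws per bin, each sex contributes 10,000 resampled subjects, and both sexes share the same number of samples in every bin.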

Confounder-free neural network (CF-Net)

Suppose we have N training MR images \({\boldsymbol{{\mathcal{X}}}}={\{{{\bf{X}}}_{i}\}}_{i = 1}^{N}\) and their corresponding target-prediction values \({\bf{y}}={\{{y}_{i}\}}_{i = 1}^{N}\), where \({y}_{i}\in [0,1]\) for classification problems and is a continuous variable for regression problems. Let us assume that the study is confounded by a set of k variables and their values are denoted by \({\bf{C}}={\{{{\bf{c}}}_{i}\}}_{i = 1}^{N}\), where each \({{\bf{c}}}_{i}=[{c}_{i}^{1},...,{c}_{i}^{k}]\) is a k-dimensional vector denoting the k confounders of subject i. To train a deep neural network for predicting the target value for each input MR image X, we first apply a Feature Extraction (\({\mathbb{FE}}\)) network to the image, resulting in a feature vector F. A Classifier (\({\mathbb{P}}\)) is built on this feature vector to predict the target y for the input X. This ensures the discriminative power of the learned features and defines the baseline architecture of ConvNet. Now, to guarantee that these features are not biased by the confounders, we propose our end-to-end architecture as in Fig. 2. Specifically, we build another network (denoted by \({\mathbb{CP}}\)) for predicting the confounding variables from F and backpropagate its loss to the feature-extraction module in an adversarial way. We train \({\mathbb{CP}}\) only on a y-conditioned cohort consisting of subjects whose target y values are uncorrelated with all k confounders. We define the y-conditioned cohort as \({{\mathcal{X}}}_{\rho }\) with \({\rho }_{i}=1\) if \({{\bf{X}}}_{i}\in {{\mathcal{X}}}_{\rho }\), and \({\rho }_{i}=0\) otherwise. The confounders associated with the y-conditioned cohort are correspondingly denoted as \({{\bf{C}}}_{\rho }\). As a result, the feature extractor learns features that minimize the y predictor loss while being conditionally independent of the confounder by maximizing the loss of \({\mathbb{CP}}\) for \({{\mathcal{X}}}_{\rho }\).

Each of the above networks has some underlying trainable parameters, defined as \({{\boldsymbol{\theta }}}_{fe}\) for \({\mathbb{FE}}\), \({{\boldsymbol{\theta }}}_{p}\) for \({\mathbb{P}}\), and \({{\boldsymbol{\theta }}}_{cp}\) for \({\mathbb{CP}}\). \({\mathbb{P}}\) forces the feature extractor to learn features to better predict \({y}_{i}\) by backpropagating the prediction loss. Let \({\hat{y}}_{i}={\mathbb{P}}({\mathbb{FE}}({{\bf{X}}}_{i};{{\boldsymbol{\theta }}}_{fe});{{\boldsymbol{\theta }}}_{p})\) be the predicted \({y}_{i}\); then the prediction loss can be characterized by binary cross-entropy \(l({y}_{i},{\hat{y}}_{i})=-{y}_{i}\mathrm{log}\,{\hat{y}}_{i}-(1-{y}_{i})\mathrm{log}\,(1-{\hat{y}}_{i})\) for classification and by the squared error \(l({y}_{i},{\hat{y}}_{i})={({y}_{i}-{\hat{y}}_{i})}^{2}\) for regression. Finally, the prediction loss for the entire cohort is

$${L}_{p}({\mathcal{X}},{\bf{y}};{{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p})=\frac{1}{N}\mathop{\sum }\limits_{i=1}^{N}l({y}_{i},{\hat{y}}_{i}).$$
(1)
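For concreteness, Eq. (1) can be evaluated with the following minimal Python sketch. The clipping constant `eps`, which guards against taking the logarithm of zero, and all function names are assumptions introduced for illustration.

```python
import math


def binary_cross_entropy(y, y_hat, eps=1e-12):
    """Per-sample loss for classification; y_hat is clipped to avoid log(0)."""
    y_hat = min(max(y_hat, eps), 1.0 - eps)
    return -y * math.log(y_hat) - (1 - y) * math.log(1 - y_hat)


def squared_error(y, y_hat):
    """Per-sample loss for regression."""
    return (y - y_hat) ** 2


def prediction_loss(targets, predictions, per_sample_loss):
    """Eq. (1): average of the per-sample losses over the cohort."""
    total = sum(per_sample_loss(y, y_hat) for y, y_hat in zip(targets, predictions))
    return total / len(targets)


# Toy regression example with two subjects.
regression_loss = prediction_loss([1.0, 3.0], [2.0, 5.0], squared_error)
```

The same `prediction_loss` function serves both tasks; only the per-sample loss passed as the third argument changes.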

Similarly, with \({\hat{c}}_{i}={\mathbb{CP}}({\mathbb{FE}}({{\bf{X}}}_{i};{{\boldsymbol{\theta }}}_{fe});{{\boldsymbol{\theta }}}_{cp})\), we define the surrogate loss of confounder prediction for the y-conditioned cohort as

$${L}_{cp}({{\mathcal{X}}}_{\rho },{{\bf{C}}}_{\rho };{{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{cp})=-\mathop{\sum }\limits_{\kappa =1}^{k}{\text{corr}}^{2}({{\bf{c}}}^{\kappa },{\hat{{\bf{c}}}}^{\kappa }),$$
(2)

where \({\text{corr}}^{2}(\cdot ,\cdot )\) is the squared correlation between its inputs and \({{\bf{c}}}^{\kappa }\) defines the vector of the κth confounding variable in \({{\bf{C}}}_{\rho }\). Hence, the overall objective of the network with a trade-off hyperparameter λ is

$$\mathop{\min }\limits_{{{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p}}\mathop{\max }\limits_{{{\boldsymbol{\theta }}}_{cp}}{L}_{p}({\mathcal{X}},{\bf{y}};{{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p})-\lambda {L}_{cp}({{\mathcal{X}}}_{\rho },{{\bf{C}}}_{\rho };{{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{cp}).$$
(3)

This scheme is similar to the GAN formulations29 with a min–max game between two networks. In our case, \({\mathbb{FE}}\) extracts features that minimize the classification criterion, while fooling \({\mathbb{CP}}\) (i.e., making \({\mathbb{CP}}\) incapable of predicting the confounding variables). Hence, the saddle point for this optimization objective is obtained when the parameters \({{\boldsymbol{\theta }}}_{fe}\) minimize the classification loss while maximizing the loss of the confounder-prediction module. Simultaneously, \({{\boldsymbol{\theta }}}_{p}\) and \({{\boldsymbol{\theta }}}_{cp}\) minimize their respective network losses.
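The surrogate loss of Eq. (2) can be computed with a few lines of Python, shown here as a minimal sketch. Returning zero correlation when either input has no variance is an assumption made to keep the example well defined, and the function names are hypothetical.

```python
def squared_correlation(x, y):
    """Squared Pearson correlation between two equally long sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return 0.0  # assumption: a constant input is treated as uncorrelated
    return cov * cov / (var_x * var_y)


def confounder_loss(true_confounders, predicted_confounders):
    """Eq. (2): negative sum of squared correlations, one term per confounder."""
    return -sum(
        squared_correlation(c, c_hat)
        for c, c_hat in zip(true_confounders, predicted_confounders)
    )


# One confounder, perfectly (linearly) predicted: the loss reaches its minimum of -1.
loss_perfect = confounder_loss([[1.0, 2.0, 3.0]], [[2.0, 4.0, 6.0]])
```

The loss ranges from −k (every confounder perfectly linearly predictable from the features) to 0 (no linear relationship), so minimizing it trains the confounder predictor and maximizing it drives the features toward being uninformative about the confounders.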

Implementation

After normalizing confounder values to z scores, we optimize Eq. (3) based on the practice used in GANs. In each iteration, we first minimize \({L}_{p}\) on a mini batch sampled from all available training data, backpropagating the loss to update \({{\boldsymbol{\theta }}}_{fe}\) and \({{\boldsymbol{\theta }}}_{p}\). With \({{\boldsymbol{\theta }}}_{fe}\) fixed, we then minimize \({L}_{cp}\) to update \({{\boldsymbol{\theta }}}_{cp}\) by computing the correlation of Eq. (2) over subjects of a mini batch sampled from the y-conditioned cohort. Finally, with \({{\boldsymbol{\theta }}}_{cp}\) fixed, \({L}_{cp}\) is maximized by updating \({{\boldsymbol{\theta }}}_{fe}\) with respect to the correlation loss defined on a mini batch from the y-conditioned cohort.
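Written as plain gradient steps, one iteration of the alternating scheme above can be sketched as follows. The step size η and the use of vanilla gradient descent are assumptions made for readability (the experiments use the Adam optimizer), \({\mathcal{B}}\) and \({{\mathcal{B}}}_{\rho }\) denote hypothetical mini batches drawn from all training data and from the y-conditioned cohort, respectively, and scaling the last step by λ follows the weighting in Eq. (3).

$$\begin{aligned}({{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p})&\leftarrow ({{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p})-\eta \,{\nabla }_{({{\boldsymbol{\theta }}}_{fe},{{\boldsymbol{\theta }}}_{p})}{L}_{p}({\mathcal{B}}),\\ {{\boldsymbol{\theta }}}_{cp}&\leftarrow {{\boldsymbol{\theta }}}_{cp}-\eta \,{\nabla }_{{{\boldsymbol{\theta }}}_{cp}}{L}_{cp}({{\mathcal{B}}}_{\rho }),\\ {{\boldsymbol{\theta }}}_{fe}&\leftarrow {{\boldsymbol{\theta }}}_{fe}+\eta \lambda \,{\nabla }_{{{\boldsymbol{\theta }}}_{fe}}{L}_{cp}({{\mathcal{B}}}_{\rho }).\end{aligned}$$

Because \({L}_{cp}\) is the negative sum of squared correlations, the second step raises the correlation between predicted and true confounders, whereas the third step lowers it by changing the features.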

With respect to the network architecture used in the experiments, we followed the design of \({\mathbb{FE}}\) in refs. 69,64 that contained 4 stacks of 2 × 2 × 2 3D convolution/ReLU/batch-normalization/max-pooling layers, yielding 4096 intermediate features. Each of \({\mathbb{P}}\) and \({\mathbb{CP}}\) was a two-layer fully connected network. We set λ to 1 (see Supplementary Fig. 5) and used a batch size of 64 subjects and the Adam optimizer with a learning rate of 0.0002. For the 2D X-ray experiment, the \({\mathbb{FE}}\) and \({\mathbb{P}}\) components complied with the feature extractor and predictor defined in ref. 41.

Reporting summary

Further information on research design is available in the Nature Research Reporting Summary linked to this article.