Abstract
Translating machine learning algorithms into clinical applications requires addressing challenges related to interpretability, such as accounting for the effect of confounding variables (or metadata). Confounding variables affect the relationship between input training data and target outputs, so a model trained on such data learns features whose distribution is biased by the confounders. A recent promising solution, MetaData Normalization (MDN), estimates the linear relationship between the metadata and each feature using a non-trainable closed-form solution. However, this estimate is computed only from the samples in a mini-batch, so it is limited by the batch size and may make training unstable. In this paper, we extend MDN with a penalty approach (referred to as PMDN). We cast the problem as a bi-level nested optimization problem and then approximate it with a penalty method, so that the linear parameters within the MDN layer become trainable and are learned from all samples. This enables PMDN to be plugged into any architecture, including those unsuited to batch-level operations, such as transformers and recurrent models. In a synthetic experiment and on a multi-label, multi-site dataset of magnetic resonance images (MRIs), we show that PMDN improves model accuracy and achieves greater independence from confounders than MDN.
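The abstract contrasts two ways of estimating the linear map β from metadata to features. The NumPy sketch below is a hypothetical toy illustration, not the authors' implementation. It assumes fixed synthetic features F (no network), a metadata matrix X made of an intercept plus one confounder, and it omits the task loss of the outer problem, so only a penalty term λ·mean‖F − Xβ‖² drives β. `closed_form_beta` mimics MDN's per-batch least-squares estimate, while `penalty_sgd_beta` treats β as a trainable parameter updated by mini-batch SGD, so the estimate aggregates information from every sample. All function names and hyperparameters are illustrative assumptions.

```python
import numpy as np


def closed_form_beta(X, F):
    """MDN-style estimate (sketch): least-squares fit of features F on
    metadata X, computed from whatever samples are passed in (e.g. one batch)."""
    # lstsq is used instead of an explicit inverse for numerical stability
    beta, *_ = np.linalg.lstsq(X, F, rcond=None)
    return beta


def residualize(X, F, beta):
    """Remove the component of F explained linearly by X.
    (Simplification: this sketch removes the whole fitted part, intercept included.)"""
    return F - X @ beta


def penalty_sgd_beta(X, F, batch_size, lam=1.0, lr=0.05, epochs=50, seed=0):
    """Penalty-style sketch: beta is a trainable parameter that minimizes
    lam * mean ||F_b - X_b beta||^2 by SGD over shuffled mini-batches."""
    rng = np.random.default_rng(seed)
    n, m = X.shape
    beta = np.zeros((m, F.shape[1]))
    for epoch in range(epochs):
        step = lr / (1.0 + epoch)  # Robbins-Monro-style decaying step size (assumed)
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            Xb, Fb = X[idx], F[idx]
            # gradient of lam * mean squared residual with respect to beta
            grad = -2.0 * lam * Xb.T @ (Fb - Xb @ beta) / len(idx)
            beta -= step * grad
    return beta


def make_toy_data(n=2000, k=4, seed=1):
    """Hypothetical synthetic data: k features linearly driven by one confounder."""
    rng = np.random.default_rng(seed)
    confounder = rng.normal(size=(n, 1))
    X = np.hstack([np.ones((n, 1)), confounder])  # intercept + confounder
    true_beta = rng.normal(size=(2, k))
    F = X @ true_beta + 0.5 * rng.normal(size=(n, k))
    return X, F, true_beta


X, F, true_beta = make_toy_data()
beta_full = closed_form_beta(X, F)             # reference fit on all samples
beta_batch = closed_form_beta(X[:8], F[:8])    # MDN-style fit on one batch of 8
beta_pen = penalty_sgd_beta(X, F, batch_size=8)  # trainable beta, batches of 8
print("single-batch error:", np.abs(beta_batch - beta_full).max())
print("penalty/SGD error: ", np.abs(beta_pen - beta_full).max())
```

In this toy setting the single-batch closed-form estimate depends on which 8 samples it sees, whereas the SGD-trained β approaches the all-sample least-squares solution even though every update uses only 8 samples. Because no step needs batch statistics, the same update would also work with a batch size of 1, which is the property the abstract relies on for architectures that do not run batch-level operations.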
Acknowledgements
This study was partially supported by NIH Grants (AA017347, MH113406, and MH098759) and Stanford Institute for Human-Centered AI (HAI) Google Cloud Platform (GCP) Credit.
Copyright information
© 2022 The Author(s), under exclusive license to Springer Nature Switzerland AG
About this paper
Cite this paper
Vento, A., Zhao, Q., Paul, R., Pohl, K.M., Adeli, E. (2022). A Penalty Approach for Normalizing Feature Distributions to Build Confounder-Free Models. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds) Medical Image Computing and Computer Assisted Intervention – MICCAI 2022. MICCAI 2022. Lecture Notes in Computer Science, vol 13433. Springer, Cham. https://doi.org/10.1007/978-3-031-16437-8_37
DOI: https://doi.org/10.1007/978-3-031-16437-8_37
Publisher Name: Springer, Cham
Print ISBN: 978-3-031-16436-1
Online ISBN: 978-3-031-16437-8
eBook Packages: Computer Science, Computer Science (R0)