Introduction

Supervised learning with deep neural networks has achieved state-of-the-art performance in a diverse range of applications. An adequate number of labeled samples is essential for training these systems, but most real-world data is unlabeled. Label generation can be cumbersome and expensive, and it is a major barrier to the development and testing of such systems [1].

Ideally, when confronted with a task and unlabeled data, one would like to estimate how many examples need to be labeled to train a neural network for that task. In this paper, we take a step towards addressing this problem.

Consider a fully connected neural network f of pre-specified dimensions and a dataset X, which is initially unlabeled, but for which labels y can be obtained when needed. We define the minimum convergence size (MCS) for f on X to be the smallest number n such that a subset Xn of n examples drawn at random from X can be labeled and used to train f as a non-trivial classifier, that is, one whose area-under-the-curve (AUC) on a held-out test set is greater than 0.5:

$$\mathrm{MCS} := \operatorname*{arg\,min}_{n}\Big( E\big[\mathrm{AUC}\big(f_{X_n,\,y_n}(X_{\mathrm{test}}),\, y_{\mathrm{test}}\big)\big] > 0.5 \Big) \qquad (1)$$

Given that outcomes are balanced, an AUC > 0.5 implies that a model has identified some signal in the underlying data, and if that AUC is measured on the test set, the identified signal generalizes to unseen data. In this scenario, below the MCS we would expect little or no correlation between sample size and model performance as measured by AUC, whereas above the MCS we would expect a positive correlation.
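When labels can in fact be obtained, Eq. (1) suggests a direct (if label-hungry) way to probe the MCS: label random subsets of increasing size and record the first size at which the expected test-set AUC exceeds 0.5. The sketch below illustrates this for a binary task with NumPy feature arrays; scikit-learn's MLPClassifier stands in for f, and the candidate sizes, layer widths, and number of repeats are illustrative assumptions rather than the protocol used in our experiments.

```python
# Illustrative sketch (not the exact experimental protocol): probe the MCS by
# labeling random subsets of increasing size and checking test-set AUC > 0.5.
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score

def probe_mcs(X_train, y_train, X_test, y_test, sizes, n_repeats=5, seed=0):
    rng = np.random.default_rng(seed)
    for n in sizes:                                   # candidate sample sizes, ascending
        aucs = []
        for _ in range(n_repeats):                    # average over random subsets
            idx = rng.choice(len(X_train), size=n, replace=False)
            f = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=200)
            f.fit(X_train[idx], y_train[idx])
            aucs.append(roc_auc_score(y_test, f.predict_proba(X_test)[:, 1]))
        if np.mean(aucs) > 0.5:                       # first size with E[AUC] > 0.5
            return n
    return None                                       # MCS not reached within `sizes`
```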

We propose a method for empirically determining the MCS for f on X using only unlabeled data, and we call this estimate the Minimum Convergence Sample Estimate (MCSE). We do this by first constructing an autoencoder g [2] whose encoder has a similar number of parameters and hidden layers as f. We train g on increasingly large (unlabeled) subsets Xi of X; the matched architecture may encourage similar layer-wise learning in f and g. Under these circumstances, we show empirically that, at each step i, the reconstruction loss L of g is related to the generalization performance of f trained on a similarly sized sample, and we demonstrate how this relationship can be used to determine the MCSE for f on X (Fig. 1).
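A minimal sketch of this unlabeled-data procedure follows, assuming the data arrive as a NumPy array of flattened examples; the layer widths, latent dimension, optimizer, and full-batch training loop are simplifying assumptions rather than the exact training configuration.

```python
# Sketch: an autoencoder g whose encoder roughly mirrors f is trained on
# subsets of increasing size, and the final reconstruction loss is recorded
# at each size. Architecture and optimizer settings are illustrative.
import torch
import torch.nn as nn

class AE(nn.Module):
    def __init__(self, d_in=784, d_latent=3):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(d_in, 128), nn.ReLU(),
                                     nn.Linear(128, d_latent))
        self.decoder = nn.Sequential(nn.Linear(d_latent, 128), nn.ReLU(),
                                     nn.Linear(128, d_in))

    def forward(self, x):
        return self.decoder(self.encoder(x))

def reconstruction_losses(X, sizes, epochs=20, lr=1e-3):
    """Train a fresh autoencoder on each subset size; return the final MSE losses."""
    losses = []
    for n in sizes:
        g, loss_fn = AE(X.shape[1]), nn.MSELoss()
        opt = torch.optim.Adam(g.parameters(), lr=lr)
        x = torch.as_tensor(X[:n], dtype=torch.float32)
        for _ in range(epochs):                 # full-batch training for simplicity
            opt.zero_grad()
            loss = loss_fn(g(x), x)
            loss.backward()
            opt.step()
        losses.append(loss.item())              # reconstruction loss L at size n
    return losses
```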

Fig. 1: Minimum convergence sample estimation can be used to approximate the number of labels required for generalizable performance.

a A fully connected network is trained on labeled data and evaluated on a held-out test set. Generalization performance is measured via AUC. The minimum convergence sample (MCS) reflects the minimum number of labeled samples required for the fully connected network to start generalizing. b An autoencoder with a structure similar to the fully connected network is trained on unlabeled data, and its reconstruction loss is used to estimate how well the FCN would generalize. The minimum convergence sample estimate (MCSE) approximates the minimum convergence sample (MCS).

As an example, consider classification of the MNIST [3] dataset with a fully connected neural network (Fig. 2). A comparison of the test set AUC curve of f and the loss curve of autoencoder g shows that their inflection points occur at similar sample sizes. We then define the MCSE for f on MNIST as the sample size corresponding to the inflection point in the loss function of g:

$$\mathrm{MCSE} := \operatorname*{arg\,max}_{n}\, \frac{\mathrm{d}^{2}L}{\mathrm{d}n^{2}} \approx 4.5\ \text{(in this case)} \qquad (2)$$
Fig. 2: Comparison of autoencoder learning curve with generalizability of a fully connected network.

f is a fully connected network with an input dimension of 784 and an output dimension of 10, and g is an autoencoder with an input dimension of 784 and a three-dimensional latent space. a The autoencoder loss splits into two phases: a quick phase and a slow phase. b The first derivative of the autoencoder loss displays a decay phase and a growth phase. c The second derivative reveals a sharp inflection point where the slope changes from sharply decreasing to sharply increasing. d The area-under-the-curve metric on the test set displays a biphasic structure: a rapid growth phase and a slow growth phase. e The first derivative of the AUC curve reveals a rapidly increasing phase followed by a decreasing phase. f The second derivative of the AUC curve reveals an inflection point that mirrors the one in the autoencoder loss curve.
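A sketch of how the inflection point in panels a–c might be located numerically is given below; differentiating with respect to log10 of the sample size and using np.gradient for the finite differences are our assumptions about the numerics, not a prescription from the method itself.

```python
# Locate the MCSE as the sample size at the maximum of the second derivative
# of the loss curve (Eq. (2)), computed on a log10 sample-size axis.
import numpy as np

def mcse_from_losses(sizes, losses):
    log_n = np.log10(np.asarray(sizes, dtype=float))
    L = np.asarray(losses, dtype=float)
    dL = np.gradient(L, log_n)      # first derivative, as in Fig. 2b
    d2L = np.gradient(dL, log_n)    # second derivative, as in Fig. 2c
    i = int(np.argmax(d2L))         # arg max of d^2 L / d n^2
    return sizes[i], log_n[i]       # MCSE and its log10 value
```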

With sample sizes above the MCSE, the learnability of the dataset by f may be approximated by the ease with which g learns a latent space that fully represents the data. We hypothesized the following relationship between the generalization power of the classifier and the learnability of the dataset by the corresponding autoencoder:

$$\frac{\mathrm{d}}{\mathrm{d}n}\,\mathrm{AUC}(f_n) \simeq \begin{cases} 0, & \text{for } n < \mathrm{MCSE},\\ -\beta\,\dfrac{\mathrm{d}}{\mathrm{d}n}\,L(g_n), & \text{for } n \ge \mathrm{MCSE} \end{cases} \qquad (3)$$

β is a scaling constant. We tested this hypothesis by calculating the correlation coefficient below and above the MCSE, and the results are reported in Table 1. A significant R2 indicates a linear correlation between autoencoder loss and classifier power. We used eight standard computer vision datasets to demonstrate our method. MNIST, EMNIST [4], QMNIST [5], and KMNIST [6] are character-recognition datasets composed of 28x28 pixel grayscale images; FMNIST is a comparable dataset of 28x28 pixel grayscale images of clothing items.

Other pre-hoc methods, such as those based on empirical process theory, did not extend well to non-linear models [16]. Post-hoc methods usually involve fitting a learning curve, but fitting a learning curve is trivial for minimum convergence sample estimation because any amount of data should result in a non-zero increase in performance on the training dataset. Moreover, such methods are task-specific, data-specific, and model-specific: a learning curve has no relevance outside the specific task, model, and dataset it was fit to. Nevertheless, while our experiments validate minimum convergence sample estimation on toy datasets, synthetic data, and, owing to data availability, a single real-world example of medical imaging, future work should validate this method across additional tasks and imaging types in the healthcare context.

Our second contribution is a method to empirically estimate the MCSE for a given fully connected neural network f. This allows users to predict the statistical power of a model without training on the entire training set during every trial. It also provides an uncertainty on the estimate, whose variance is inversely related to how structured the underlying data is. Our third contribution is a publicly available tool for minimum sample size estimation for fully connected neural networks.

Importantly, there are several natural opportunities to extend our work to more complex models, as discussed below. First, our paper only considered a fully connected network with a relatively simple architecture. One natural question arising from this work is how the method fares in estimating the statistical power of convolutional or recurrent neural networks. While adding convolutions would be relatively easy via the addition of another layer, adding attention mechanisms may require additional structural modifications to fully approximate the statistical power of recurrent neural networks or transformers. For our method to be applicable to medical imaging tasks, we anticipate that extending this work to convolutional neural networks remains an important next step. Future work can validate MCSE on more complex architectures utilizing pre-trained networks and skip connections. Second, the loss function used in the current analysis was the reconstruction loss, which is a relatively simple choice. For variational autoencoders, the loss additionally includes a KL-divergence term, while GANs use the JS-divergence and WGANs the Wasserstein distance [31].
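As an illustration of that extension (not part of the method evaluated here), a standard formulation of the variational autoencoder objective augments the reconstruction term with a KL-divergence between the approximate posterior and a unit Gaussian prior; a minimal sketch is shown below.

```python
# Sketch of a standard VAE objective: reconstruction error plus the
# KL-divergence between q(z|x) = N(mu, diag(exp(log_var))) and N(0, I).
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, log_var, beta=1.0):
    recon = F.mse_loss(x_hat, x, reduction="sum")                    # reconstruction term
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())   # KL(q(z|x) || N(0, I))
    return recon + beta * kl
```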

After taking the second derivative of the loss function, we located the inflection point and recorded its corresponding sample size as well as the value of the autoencoder loss at that point. Figure 2 was generated by plotting the autoencoder loss of g and the AUC of f against sample size. The MCSE is drawn as a vertical line in Fig. 3, coinciding with the inflection point of the autoencoder loss function and providing a lower bound on the sample size required to improve model performance. The shaded areas represent the error, determined as the autoencoder loss at the log(MCSE ± 1) sample sizes.

Third, we determined the correlation between autoencoder loss and area-under-the-curve using R2, Kendall's τ, and Spearman's ρ (Table 1). These values demonstrated a significant correlation above, but not below, the MCSE (the inflection point of the autoencoder loss curve). To better demonstrate this finding, we plotted the results for Eq. (3) for values above the MCSE (Fig. 3).
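A sketch of this correlation analysis is given below, assuming the per-sample-size loss and AUC values are available as arrays; the use of scipy.stats and of squared Pearson's r for R2 are our assumptions about the exact computation, not a specification of it.

```python
# R^2, Kendall's tau, and Spearman's rho between autoencoder loss and test AUC,
# computed separately below and above the MCSE.
import numpy as np
from scipy import stats

def correlations(losses, aucs, sizes, mcse):
    out = {}
    for name, mask in [("below", np.asarray(sizes) < mcse),
                       ("above", np.asarray(sizes) >= mcse)]:
        L, A = np.asarray(losses)[mask], np.asarray(aucs)[mask]
        r, _ = stats.pearsonr(L, A)
        tau, _ = stats.kendalltau(L, A)
        rho, _ = stats.spearmanr(L, A)
        out[name] = {"R2": r ** 2, "kendall_tau": tau, "spearman_rho": rho}
    return out
```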

Finally, we generalize these results to an n-dimensional hyper-cube and validate them on a medical imaging dataset. To generate the synthetic dataset, we randomly sampled data points from an n-dimensional hyper-cube with side lengths equal to the class separation, and used these samples to train the fully connected network classifier and an autoencoder. For the medical dataset, we use the publicly available Kaggle Chest X-ray dataset [34] and accurately predict the minimum number of labeled samples required to learn a meaningful classifier with a fully connected network.
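One plausible reading of this synthetic construction is sketched below; how the two classes are placed relative to the hyper-cube is not fully specified above, so the offset-cube construction here is an assumption made for illustration only.

```python
# Illustrative synthetic data: two classes drawn from d-dimensional cubes with
# side length equal to the class separation, the second cube shifted by that
# separation along every axis.
import numpy as np

def hypercube_classes(n_per_class, d, separation, seed=0):
    rng = np.random.default_rng(seed)
    X0 = rng.uniform(0.0, separation, size=(n_per_class, d))               # class 0 cube
    X1 = rng.uniform(0.0, separation, size=(n_per_class, d)) + separation  # class 1 cube, shifted
    X = np.vstack([X0, X1])
    y = np.concatenate([np.zeros(n_per_class), np.ones(n_per_class)])
    return X, y
```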

Analysis of the publicly available NIH CXR dataset was carried out with approval of the Institutional Review Board at Icahn School of Medicine at Mount Sinai, New York, NY 10019. The requirement for informed consent was waived as the dataset was completely de-identified.

Reporting summary

Further information on research design is available in the Nature Research Reporting Summary linked to this article.