1 Introduction

In a restricted multi-center learning environment where each chunk of data is available only at the corresponding center, a model must be learned incrementally without access to previous data chunks. Consider the scenario in which privacy-sensitive medical data are spread across multiple hospitals, so that a machine learning model has to be learned sequentially. If all data were available concurrently, training a state-of-the-art deep learning model such as ResNet for image recognition [5] or GNMT for machine translation [15] would be a good solution. However, if a data chunk from one stage is no longer available in the following learning stages, it is hard to preserve the knowledge learned from that chunk because of the phenomenon known as catastrophic forgetting [4]. This is especially problematic in neural networks optimized with gradient descent [12].

Overcoming catastrophic forgetting is one of the key research topics in deep learning. One naive approach is to fine-tune (FT) the model with the data accessible at each stage, starting from the up-to-date model parameters [2]. Learning without Forgetting (LwF) is a representative method for overcoming catastrophic forgetting in neural networks [11]. Before training starts in the current stage, output logits (LwF-logits) of the current training examples are calculated, so that each example is paired with both its true label and its pre-calculated LwF-logit. The LwF-logits are used as pseudo labels for preserving old knowledge. Elastic Weight Consolidation (EWC) maintains old knowledge by constraining important weights (i.e., model parameters) not to vary too much [8]. The relative importance between weights is defined based on the Fisher information matrix. Deep Generative Replay (GR) [13] uses a generative adversarial network [3]. GR learns a generative model and a task-solving model at the same time, and the learned generator is used for sampling old data during the current learning stage. The concept of GR is interesting, but samples from generative models are not suitable for certain applications such as medical imaging, where pixel-level details carry important radiographic features for diagnosis.

LwF and EWC are representative approaches for preventing catastrophic forgetting in neural networks, based on two distinct philosophies: controlling the output activations (LwF) or the model parameters (EWC). In this work, we preserve knowledge by modeling the feature space directly. Based on the assumption that there exists a feature space better suited to knowledge preservation, we model the high-level feature space and the output (logit) space to be mutually informative, and constrain the feature space to remain in the modeled space during training. Experimental validation shows that the proposed method preserves more knowledge than previous approaches.

2 Baseline Models

LwF and EWC were originally proposed for preventing catastrophic forgetting in multi-task learning, where each task has its own data and the data used in previous tasks are not available when solving the current task. We call this multi-center multi-task learning. We focus on multi-center single-task learning, where the model is learned with different data chunks of the same task and access to each data chunk is restricted. In this section, we define several baseline models for the multi-center single-task learning environment.

Fine-tuning (FT) trains a model incrementally based on the model parameters learned in the previous stage. Figure 1(a) shows the model architecture for FT. \(X_n\), Z, and \(Y_n\) are random variables for the input, latent, and output spaces, respectively. The target loss function \(L_n(\theta )\) (e.g., negative log-likelihood for classification) optimizes the model parameters \(\theta \), which consist of \(\theta _s\) (shared) and \(\theta _n\) (new). In the first stage, \(\theta \) is randomly initialized. In the following stages, \(\theta \) is restored from the model learned in the previous stage.

Learning without Forgetting (LwF) trains a model using both ground-truth labels and pseudo labels (pre-calculated LwF-logits). Figure 1(b) shows the K-th learning stage. \(Y_n\) and \(Y_{o_i}\) are the model's outputs for the current stage and the i-th stage, for i in \(\{1, \dots , K-1\}\). The loss function is

$$\begin{aligned} L(\theta ) = L_n(\theta ) + L_{LwF}(\theta ), \qquad L_{LwF}(\theta ) = \sum _i \lambda _{LwF} L_{o_i}(\theta ), \end{aligned}$$
(1)

where \(L_n(\theta )\) is the loss between the model output \(y_n \in Y_n\) and its ground-truth label. \(L_{o_i}(\theta )\) is the loss between the model output \(y_{o_i} \in Y_{o_i}\) and its LwF-logit, and \(\lambda _{LwF}\) is a weighting constant. \(\theta _s\) and \(\theta _n\) are initialized randomly in the first stage and restored from the previous stage in the following stages. In the K-th stage, \(\theta _{o_{K-1}}\) is initialized with \(\theta _n\) of the (\(K-1\))-th stage and fine-tuned until the final stage. In the third stage, for example, \(\theta _{o_1}\) and \(\theta _{o_2}\) are restored from \(\theta _{o_1}\) and \(\theta _n\) of the second stage, respectively. For classification tasks, \(L_n(\theta )\) and \(L_{o_i}(\theta )\) are typically the cross-entropy loss.
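To make Eq. (1) concrete, the following is a minimal sketch of the LwF loss in TensorFlow 2 style (the paper's experiments use TensorFlow, but this is not the authors' code); the function and argument names are ours, and temperature scaling of the distillation targets is omitted.

```python
import tensorflow as tf

def lwf_total_loss(y_true, new_logits, old_logits_pred, lwf_logits_list,
                   lwf_lambda=0.1):
    """Sketch of Eq. (1): task loss on the new head plus cross-entropy losses
    tying each old head to its pre-computed LwF-logits."""
    # L_n: cross-entropy between ground-truth labels and the new head's logits.
    l_n = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true, logits=new_logits))
    # L_{o_i}: soft-target cross-entropy between each old head's output and
    # its stored LwF-logits (pseudo labels recorded before this stage).
    l_lwf = 0.0
    for y_o, lwf_logits in zip(old_logits_pred, lwf_logits_list):
        l_lwf += tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.nn.softmax(lwf_logits), logits=y_o))
    return l_n + lwf_lambda * l_lwf
```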

Fig. 1. Model architectures: (a) FT/EWC, (b) LwF, and (c) modified LwF (LwF+).

In the multi-center multi-task learning environment, LwF preserves old knowledge by constraining the outputs of the old task-specific layers with the corresponding pseudo labels. However, finding a feature space that is optimal for all the tasks becomes harder as the number of tasks (i.e., output branches) increases.

Modified LwF (LwF+): LwF can be modified for multi-center single-task learning. All the previous task-specific layers are merged into a single knowledge-preserving layer, as shown in Fig. 1(c). The loss function then becomes,

$$\begin{aligned} L(\theta ) = L_n(\theta ) + L_{LwF+}(\theta ), ~~~~ L_{LwF+}(\theta ) = \lambda _{LwF+} L_{o}(\theta ), \end{aligned}$$
(2)

where \(L_{o}(\theta )\) is the loss between \(y_o \in Y_o\) and its pseudo label (LwF-logit). \(\theta _s\) and \(\theta _n\) are initialized randomly in the first stage and restored from the previous model in the following stages. \(\theta _o\) is initialized with \(\theta _n\) from the first stage and fine-tuned until the end of the learning stages.

Elastic Weight Consolidation (EWC) constrains the model parameters by defining the importance of weights. Each parameter has its own weight-decay constant; the more important a parameter is, the larger the weight-decay constant. Based on the model in Fig. 1(a), the loss function is,

$$\begin{aligned} L(\theta ) = L_n(\theta ) + L_{EWC}(\theta ), ~~~~ L_{EWC}(\theta ) = \sum _j \frac{\lambda _{EWC}}{2} F_j(\theta _j - \theta _{p,j}^*)^2, \end{aligned}$$
(3)

where \(\theta _{p,j}^*\) is the j-th model parameter learned in the previous stage and \(F_j\) is the j-th element of the diagonal of the Fisher matrix F for weighting the j-th model parameter \(\theta _j\). \(\lambda _{EWC}\) is a weighting constant. \(\theta _s, \theta _n\) are randomly initialized in the first stage and restored from the previous model for the following stages.
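A minimal sketch of the EWC penalty in Eq. (3), together with one common way to estimate the diagonal Fisher information (the empirical Fisher from squared gradients); the function names and the choice of estimator are our assumptions, not necessarily what the paper used.

```python
import tensorflow as tf

def ewc_penalty(params, prev_params, fisher_diag, ewc_lambda=1.0):
    """Eq. (3): anchor each parameter to its previous-stage value, weighted by
    the diagonal of the Fisher information matrix."""
    penalty = 0.0
    for theta, theta_prev, f in zip(params, prev_params, fisher_diag):
        penalty += tf.reduce_sum(f * tf.square(theta - theta_prev))
    return 0.5 * ewc_lambda * penalty

def empirical_fisher_diag(model, dataset, loss_fn):
    """Diagonal Fisher estimate: mean squared gradient of the loss over
    batches drawn from the previous stage's data."""
    fisher = [tf.zeros_like(v) for v in model.trainable_variables]
    num_batches = 0
    for x, y in dataset:
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=False))
        grads = tape.gradient(loss, model.trainable_variables)
        fisher = [f + tf.square(g) if g is not None else f
                  for f, g in zip(fisher, grads)]
        num_batches += 1
    return [f / float(num_batches) for f in fisher]
```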

EWCLwF (EWCLwF+) is the combination of EWC and LwF (LwF+). Since the two methods keep old knowledge based on two distinct approaches, they can be used complementarily. Based on the model architecture described in Fig. 1(b) with the loss function in Eq. (1), \(L_{EWC}(\theta )\) in Eq. (3) is added, so the loss function becomes \(L(\theta ) = L_n(\theta ) + L_{LwF}(\theta ) + L_{EWC}(\theta )\). EWCLwF+ is similar: based on the model LwF+ in Fig. 1(c) with the loss in Eq. (2), the target loss becomes \(L(\theta ) = L_n(\theta ) + L_{LwF+}(\theta ) + L_{EWC}(\theta )\).
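The combined objective can be sketched by simply reusing the helper functions from the two sketches above (all names are ours):

```python
# EWCLwF sketch: L(theta) = L_n(theta) + L_LwF(theta) + L_EWC(theta),
# reusing lwf_total_loss and ewc_penalty defined earlier.
def ewclwf_loss(y_true, new_logits, old_logits_pred, lwf_logits_list,
                params, prev_params, fisher_diag,
                lwf_lambda=0.1, ewc_lambda=1.0):
    return (lwf_total_loss(y_true, new_logits, old_logits_pred,
                           lwf_logits_list, lwf_lambda)
            + ewc_penalty(params, prev_params, fisher_diag, ewc_lambda))
```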

All the presented models originate from the two representative methods for knowledge preservation in neural networks. Details of the experimental set-up for the baseline models will be explained in Sect. 4.

3 Proposed Methodology

In a general neural network model as in Fig. 1(a), the output \(Y_n\) of the input data \(X_n\) is compared with its true label, and the error is propagated backward from top to bottom, which encourages the latent variable Z to be task-specific. To keep the previously learned knowledge, the latent space Z should be informative enough to include the information of the input \(X_n\).

Fig. 2. Proposed model architecture: (a) the first learning stage and (b) the following learning stages.

Fig. 3. Top layers of ResNet: based on (a) an fc layer or (b) a \(conv_{1\times 1}\) layer. Both are functionally equivalent.

While learning the feature extractor f of \(\theta _s\) and the classifier g of \(\theta _n\), the inverse function h of g (\(h=g^{-1}\)) can be approximately modeled by minimizing the L2 distance between the latent vector \(z \in Z\) and its reconstruction h(g(z)), as in Fig. 2(a). Without any constraints, minimizing the reconstruction loss alone easily makes the latent space Z trivial in terms of the information it can represent, i.e., the entropy \(\mathbf H (Z)\) becomes low. Since Z must be informative enough to minimize the task-solving loss \(L_n(\theta )\), joint learning with both the reconstruction and task-solving losses prevents Z from becoming trivial. It is known that the conditional entropy \(\mathbf H (Z \mid Y_n)\) can be minimized by minimizing the reconstruction error of Z under the auto-encoder framework [14], while minimizing the task-solving loss \(L_n(\theta )\) keeps \(\mathbf H (Z)\) from decreasing too much. As a result, Z and \(Y_n\) become mutually informative through joint learning with the two losses.
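The joint first-stage objective can be sketched as follows (a simplified illustration with our own function and argument names; z is treated here as a 1-D latent vector, whereas the paper's modified ResNet keeps a 3-D feature map):

```python
import tensorflow as tf

def stage1_loss(y_true, z, g, h, rec_lambda=1.0):
    """First-stage sketch: the task loss keeps Z informative (H(Z) does not
    collapse), while the L2 reconstruction of Z through h(g(z)) ties Z to the
    logit space, making Z and Y_n mutually informative."""
    logits = g(z)                                      # Y_n = g(Z)
    z_rec = h(logits)                                  # reconstruction h(g(Z))
    l_n = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true, logits=logits))             # task-solving loss L_n
    l_rec = tf.reduce_mean(
        tf.reduce_sum(tf.square(z - z_rec), axis=-1))  # L2 reconstruction loss
    return l_n + rec_lambda * l_rec
```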

Figure 2 shows the proposed model architecture. In the first stage, f, g, and h (parameterized by \(\theta _s\), \(\theta _n\), and \(\theta _r\), respectively; all randomly initialized in the first stage) are learned by minimizing the task-solving and reconstruction losses concurrently. In the next stage, the parameters \(\theta _o\) and \(\theta _r\) of the functions \(g'\) and \(h'\) are restored from \(\theta _n\) and \(\theta _r\) of the first stage and kept fixed during the rest of the learning stages. \(Y_n\) and \(Y_o\) are the outputs for solving the task with the current data and for preserving previously learned knowledge, respectively. Based on the loss function for LwF+ in Eq. (2), the target Z space modeled in the first stage can be kept in the following stages by fixing \(\theta _o\) of \(g'\) and \(\theta _r\) of \(h'\) and guiding the output \(Y_o\) with LwF-logits. The loss function is

$$\begin{aligned} L(\theta ) = L_n(\theta ) + L_{LwF+}(\theta ) + L_{rec}(\theta ), ~~~~ L_{rec}(\theta ) = \lambda _{rec} L2(\theta ), \end{aligned}$$
(4)

where \(\lambda _{rec}\) is a weighting constant for the reconstruction loss. LwF-logits for \(Y_o\) are calculated in the same manner as in LwF+. From the second stage onward, \(\theta _s\) and \(\theta _n\) are initialized with the parameters learned in the previous stage and fine-tuned with the data of the corresponding stage until the end of the learning process.

Since we bound the Z space to the space modeled in the first stage and fix \(\theta _o\), \(\theta _r\), and \(Y_o\) (with LwF-logits), f tries to pull the new data examples into the modeled space, which remembers the previous data examples.
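A minimal sketch of Eq. (4) for a stage K ≥ 2, with our own names: g_old and h_old correspond to the frozen \(g'\) (\(\theta _o\)) and \(h'\) (\(\theta _r\)) from the first stage, g_new to \(\theta _n\), and lwf_logits are pre-computed with the previous model before the stage begins.

```python
import tensorflow as tf

def later_stage_loss(y_true, z, lwf_logits, g_new, g_old, h_old,
                     lwf_lambda=0.1, rec_lambda=1.0):
    """Sketch of Eq. (4). Only f (producing z) and g_new (theta_n) are
    trainable; g_old (theta_o) and h_old (theta_r) stay frozen."""
    new_logits = g_new(z)                  # Y_n: current-stage task head
    old_logits = g_old(z)                  # Y_o: knowledge-preserving head
    z_rec = h_old(old_logits)              # reconstruction in the fixed space
    l_n = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true, logits=new_logits))
    l_lwf = lwf_lambda * tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.nn.softmax(lwf_logits), logits=old_logits))
    l_rec = rec_lambda * tf.reduce_mean(
        tf.reduce_sum(tf.square(z - z_rec), axis=-1))
    return l_n + l_lwf + l_rec
```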

Fig. 4. Proposed model described in Fig. 2, based on the modified ResNet in Fig. 3.

4 Experiments

We compare the proposed method with the baseline models on several image classification tasks. The base network is ResNet [5], which consists of multiple residual blocks and average-pooling (avgpool) followed by a fully-connected (fc) layer, as shown in Fig. 3(a). The 3-D feature map \(Z_{3d}\) extracted from the top-most residual block is pooled into a 1-D feature vector \(Z_{1d}\) via avgpool, and the output vector \(Y_{1d}\) is obtained from \(Z_{1d}\) through the final fc. Given \(z_{3d} \in Z_{3d}\) of an input example, \(y_{1d} \in Y_{1d}\) is given by \(g_{\theta _{fc}}(avgpool(z_{3d}))\), where g is the fc layer parameterized by \(\theta _{fc}\). g and avgpool are commutative because avgpool is a linear operation. Based on the modified model in Fig. 3(b), the output can be described as \(y_{1d} = avgpool(g_{\theta _{{conv}_{1\times 1}}}(z_{3d}))\), where g is now a 1\(\times \)1 convolution layer (\(conv_{1\times 1}\)) parameterized by \(\theta _{{conv}_{1\times 1}}\). We use the modified ResNet in order to model the approximate inverse function h accurately before avgpool. Both models are functionally equivalent, but the modified one requires more computation than the original ResNet. The proposed network architecture is shown in Fig. 4. \(\theta _n\) and \(\theta _o\) are the parameters of the \(conv_{1\times 1}\) layers, which replace the fc layer of the original ResNet.
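Because avgpool is linear, the two top-layer variants in Fig. 3 compute the same output. A short numerical check of this equivalence (shapes and names are illustrative, assuming an 8×8×64 feature map and 10 classes):

```python
import numpy as np
import tensorflow as tf

# fc(avgpool(z)) == avgpool(conv1x1(z)) when the 1x1 kernel reuses the fc weights.
z3d = tf.random.normal([2, 8, 8, 64])          # (batch, H, W, C) feature map Z_3d
w = tf.random.normal([64, 10])                 # fc weights (C -> num_classes)
b = tf.random.normal([10])                     # shared bias

y_fc = tf.matmul(tf.reduce_mean(z3d, axis=[1, 2]), w) + b          # Fig. 3(a)
y_conv = tf.reduce_mean(
    tf.nn.conv2d(z3d, tf.reshape(w, [1, 1, 64, 10]),
                 strides=1, padding='VALID') + b,
    axis=[1, 2])                                                   # Fig. 3(b)

print(np.allclose(y_fc.numpy(), y_conv.numpy(), atol=1e-4))        # True
```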

Three datasets are used for experimental validation: CIFAR-10/100 [9] for natural image classification and chest X-rays (CXRs) for medical image classification. ResNet-56, ResNet-110, and ResNet-21 are the base models for CIFAR-10, CIFAR-100, and CXRs, respectively. Each network consists of an initial convolution layer, three sets of N consecutive residual blocks, and a final \(conv_{1\times 1}\) layer. In ResNet-21, an additional convolution layer (kernel 3\(\times \)3, filter width 32, stride 2) with maxpooling (kernel 2\(\times \)2, stride 2) is added as conv-bn-relu-maxpool (bn: batch normalization [7], relu: rectified linear unit [10]) before the initial convolution to expand the receptive field for the large-size CXRs. Table 1 summarizes the layer components. The top layer of ResNet-21 is modified from its original architecture, as explained in Sect. 4.2. The approximate inverse function h (of g), parameterized by \(\theta _{r}\) in Fig. 4, consists of multiple consecutive convolutions: four 3\(\times \)3 (stride 1) convolution layers with filter widths (64, 128, 128, 64) for ResNet-56 (CIFAR-10), three with filter widths (256, 256, 256) for ResNet-110 (CIFAR-100), and three with filter widths (32, 64, 128) for ResNet-21 (CXRs), each followed by a single bn-relu.
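As an illustration of how h might look for the CIFAR-10 model under our reading of the description above (the 8×8 spatial size and the 10-channel logit-map input are assumptions based on the modified ResNet top in Fig. 3(b)):

```python
import tensorflow as tf

# Sketch of h for ResNet-56 / CIFAR-10: four 3x3 (stride 1) convolutions with
# filter widths 64, 128, 128, 64, followed by a single bn-relu. It maps the
# 3-D logit map Y_o (e.g. 8x8x10) back to a reconstruction of Z_3d (8x8x64).
def build_h_cifar10():
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, 3, padding='same'),
        tf.keras.layers.Conv2D(128, 3, padding='same'),
        tf.keras.layers.Conv2D(128, 3, padding='same'),
        tf.keras.layers.Conv2D(64, 3, padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
    ], name='h_inverse')

h = build_h_cifar10()
z_rec = h(tf.random.normal([2, 8, 8, 10]))      # -> shape (2, 8, 8, 64)
```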

Table 1. Layer components. N is the number of residual blocks, C denotes a conv layer, and R denotes a residual block; e.g., \(R_1\) of ResNet-110 consists of 18 residual blocks, each with two consecutive 3\(\times \)3 conv layers of filter width 64. Downsampling with stride 2 is performed by \(R_2\) and \(R_3\).

For CIFAR-10/100, the initial learning rate of 0.1 is decayed by \(\frac{1}{10}\) every 40 epochs until the 120-th epoch. For CXRs, the initial learning rate of 0.01 is decayed by \(\frac{1}{10}\) every 20 epochs until the 80-th epoch. A weight decay constant of 0.0001 and stochastic gradient descent with momentum 0.9 are used. For CIFAR-10/100, a 32\(\times \)32 image is randomly cropped from a 40\(\times \)40 zero-padded image (4 pixels on each side of the original 32\(\times \)32 image) during training [5]. Each CXR is resized to 500\(\times \)500, and a randomly cropped 448\(\times \)448 image is used for training. \(\lambda _{EWC}\) is 0.1, 10.0, and 1.0 for CIFAR-10, CIFAR-100, and CXRs, respectively, selected from the set \(\{0.1, 1.0, 10.0\}\) by cross-validation. \(\lambda _{LwF}\) in Eq. (1) is \(\frac{0.1}{K-1}\), where K is the number of learning stages including the current one. \(\lambda _{LwF+}\) and \(\lambda _{rec}\) are 0.1 and 1.0, respectively. All experiments are done with TensorFlow [1].
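For reference, the CIFAR training schedule described above could be set up roughly as follows (a sketch only; the batch size is our assumption, as it is not stated in the text):

```python
import tensorflow as tf

# CIFAR-10/100 training-setup sketch: SGD with momentum 0.9, initial lr 0.1
# decayed by 1/10 every 40 epochs up to epoch 120, weight decay 1e-4, and
# random 32x32 crops from 40x40 zero-padded images.
BATCH_SIZE = 128                                  # assumed, not given in the text
STEPS_PER_EPOCH = 40000 // BATCH_SIZE             # 40k training images remain after
                                                  # holding out 10k for validation
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[40 * STEPS_PER_EPOCH, 80 * STEPS_PER_EPOCH],
    values=[0.1, 0.01, 0.001])
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
WEIGHT_DECAY = 1e-4                               # added as an L2 term on the weights

def augment_cifar(image):
    """Zero-pad 4 pixels per side, then take a random 32x32 crop."""
    image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    return tf.image.random_crop(image, size=[32, 32, 3])
```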

4.1 CIFAR-10/100

CIFAR-10/100 have 10/100 classes, with 50k training and 10k test images of size 32\(\times \)32. In our experiment, 10k training images are used for validation, and the model which performs best on the validation set is selected for evaluation on the test set. The remaining 40k training images are split into four sets (10k per set). Each model is trained continually in the multi-center single-task learning set-up, where each center has 10k training images and the task is 10/100-class classification. Table 2 shows the error rates on the test set, with mean (std) of five trials. LwF+ and EWCLwF+ mostly perform better than LwF and EWCLwF, i.e., LwF+ and EWCLwF+ are more appropriate for multi-center single-task learning. The proposed method performs best, as shown in the table.

Table 2. CIFAR-10/100: test set (10k images) error rates - mean (std) of five trials.

After stage-1, the stage-1 training data (st-1-trn) are not used in the following stages anymore. We therefore evaluate the final model on st-1-trn to see how much of it has been forgotten after the final stage. For CIFAR-10, 85.75%, 85.97%, 88.64%, 88.22%, and 89.40% of st-1-trn are still classified correctly at stage-4 by FT, EWC, LwF+, EWCLwF+, and Proposed, respectively. For CIFAR-100, 58.67%, 58.85%, 65.57%, 66.91%, and 69.34% of st-1-trn are classified correctly at the final stage (same ordering).

4.2 Chest X-Rays for Tuberculosis

We experiment with a real-field medical dataset in order to verify that the proposed method is also valid in a practical set-up. A total of 10,508 de-identified CXRs (from the Korean Institute of Tuberculosis [6]) are used, consisting of 3,556 abnormal (tuberculosis; TB) and 6,952 normal cases. CXRs are commonly used for screening TB; cases which require a follow-up test are recalled by radiologists. Among the 3,556 abnormal cases, 1,438 were diagnosed as active TB (TB-A) at the screening stage. The status of the remaining 2,118 cases, which needed a follow-up sputum test, could not be specified radiologically at the screening stage (TB-U). 80% of the data are randomly selected for training and divided into four sets: 288 (TB-A), 424 (TB-U), and 1,390 (Normal) per set. The remaining 20% are split evenly for validation and test: 143 (TB-A), 211 (TB-U), and 696 (Normal) in each set.

We modify the output layer of the model in order to exploit the status information of the abnormal cases. Two output \(conv_{1\times 1}\) layers are used, for 2-class (TB vs. normal) and 3-class (TB-A, TB-U, and normal) classification, respectively. The 3-class \(conv_{1\times 1}\) is used for knowledge preservation; the 2-class \(conv_{1\times 1}\) is used only for performance measurement (AUC: area under the ROC curve).
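A sketch of the two CXR output heads (layer and function names are ours): both are \(conv_{1\times 1}\) layers on the shared feature map, followed by avgpool as in the modified ResNet.

```python
import tensorflow as tf

# Two 1x1-conv heads on the shared CXR feature map: the 3-class head
# (TB-A / TB-U / normal) is the knowledge-preserving output, and the
# 2-class head (TB vs. normal) is used only for AUC measurement.
head_3class = tf.keras.layers.Conv2D(3, kernel_size=1, name='head_preserve')
head_2class = tf.keras.layers.Conv2D(2, kernel_size=1, name='head_auc')

def cxr_outputs(z3d):
    logits_3 = tf.reduce_mean(head_3class(z3d), axis=[1, 2])   # avgpool after conv1x1
    logits_2 = tf.reduce_mean(head_2class(z3d), axis=[1, 2])
    return logits_2, logits_3
```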

Table 3 summarizes the AUC of each model, with mean (std) of five trials. Except for the first stage, the proposed method is always better than the others. The proposed method also performs best in terms of the ensemble performance of the five trials: 0.9257, 0.9205, 0.9217, 0.9271, 0.9228, 0.9172, and 0.9363 for FT, EWC, LwF, LwF+, EWCLwF, EWCLwF+, and Proposed, respectively. Figure 5 shows the ROC curves on st-1-trn at stage-4 (analogous to the CIFAR-10/100 analysis), which implicitly shows that the proposed method helps preserve old knowledge.

Table 3. CXRs for TB: test set AUC - mean (std) of five trials.

Fig. 5. ROC curves at stage-4 with stage-1 training data.

5 Conclusion

In this work, we raise the problem of catastrophic forgetting in the multi-center single-task learning environment and propose a new way to preserve old knowledge in neural networks. By modeling the high-level feature space to be appropriate for knowledge preservation in the first stage and constraining the feature space to remain in the modeled space during training in the following stages, we can preserve the knowledge learned in preceding stages. The proposed method is shown to be beneficial in terms of keeping the old knowledge in classification tasks. More experimental analysis beyond classification, such as lesion detection or segmentation, is needed, and we leave this for future work.