1 Introduction

We aim to provide whole heart segmentation in cardiac MRI for patients with congenital heart disease (CHD). This involves delineating the heart chambers and great vessels [1], and promises to enable patient-specific heart models for surgical planning in CHD [2]. CHD encompasses a vast range of cardiac malformations and topological changes. Defects can include holes in the heart walls (septal defects), great vessels connected to the wrong chamber (e.g., double outlet right ventricle; DORV), dextrocardia (left-right flip), duplication of a great vessel, a single ventricle, and/or prior surgeries creating additional atypical connections. In MRI, different chambers and great vessels locally appear very similar to each other, and there is little or no contrast at the valves and thin walls separating neighboring structures. Finally, labeled training data is very limited. This precludes modeling each CHD subtype separately in an attempt to reduce variability. Moreover, patients with unique combinations of defects and prior surgeries defy categorization. Beyond our application, limited training data is to be expected for new applications of medical imaging not yet in widespread clinical practice. This necessitates development of methods that generalize well from small, imbalanced datasets, possibly also incorporating user interaction.

State-of-the-art methods use a convolutional neural network (CNN) to directly delineate all chambers and vessels in one step [3, 4], although CNNs for CHD have largely been limited to segmenting the blood pool and myocardium [5, 6]. Direct co-segmentation of all major cardiac structures works well when applied to adult-onset heart disease, which induces much less severe shape changes than CHD. However, when trained on our small dataset of CHD patients, it fails completely on held-out subjects with severe malformations.

We develop an iterative segmentation approach that evolves a segmentation over several steps in a prescribed way and automatically estimates when to stop, beginning from a single user-placed seed for each structure. An iterative method can operate more locally, better maintain each structure’s connectivity, and propagate information from distant landmarks, similar to traditional snakes, level sets and particle filters [7]. We employ a recurrent neural network (RNN) [8], which uses context to grow the segmentation appropriately even in areas of low contrast. Recent deep learning research has also explored segmenting a single image iteratively; examples include recursive refinement of the entire segmentation map [9].

2 Iterative Segmentation Model

Given an input image \({\mathbf {x}}\) defined on the domain \(\varOmega \), we seek a segmentation label map \({\mathbf {y}}\) that assigns one of L anatomical labels to each voxel in \({\mathbf {x}}\).

Generative Model: We model the segmentation \({\mathbf {y}}\) as the endpoint of a sequence of segmentations \({\mathbf {y}_0}, \ldots , {\mathbf {y}_T}\), where \({\mathbf {y}_t}: \varOmega \rightarrow \{1, \ldots , L\}\) for time steps \(t=0, \ldots , T\). The intermediate segmentations \({\mathbf {y}_t}\) capture a growing part of the anatomy of interest. In practice, the initial segmentation map \({\mathbf {y}_0}\) is created by centering a small sphere around an initial seed point placed by the user.
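
For concreteness, a minimal sketch of this seed initialization follows; the function name and the sphere radius are our own illustrative choices, not prescribed by the model.

```python
import numpy as np

def initial_segmentation(image_shape, seed, label, radius=3):
    """Create y_0: a small sphere of `label` voxels centered on the user's seed.

    image_shape : 3-tuple, the image domain Omega
    seed        : (z, y, x) voxel coordinates of the user-placed seed
    label       : integer anatomical label (0 is background)
    """
    zz, yy, xx = np.ogrid[:image_shape[0], :image_shape[1], :image_shape[2]]
    dist2 = (zz - seed[0]) ** 2 + (yy - seed[1]) ** 2 + (xx - seed[2]) ** 2
    y0 = np.zeros(image_shape, dtype=np.int64)
    y0[dist2 <= radius ** 2] = label
    return y0
```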

The number of iterations required to achieve an accurate segmentation depends on the shape and size of the object being segmented. To capture this, we introduce a sequence of indicator variables \({s_0}, \ldots , {s_T}\), where \({s_t}\in \{0, 1\}\) specifies whether the segmentation is completed at time step t. If \({s_t}= 1\), then \({\mathbf {y}_t}\) is deemed the final segmentation and we set \({\mathbf {y}_i}= {\mathbf {y}_{i-1}}\) and \({s_i}= 1\) for all \(i > t\).

Given an image and an initial segmentation, the inference task is to compute \(p({\mathbf {y}_T}, {s_T}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}= 0)\). We assume that the segmentations \(\{ {\mathbf {y}_t}\}\) and stopping indicators \(\{ {s_t}\}\) follow a first order Markov chain given the input image:

$$\begin{aligned} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_0}, \ldots , {\mathbf {y}_{t-1}}, {s_0}, \ldots , {s_{t-1}}) = p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}), \end{aligned}$$
(1)
$$\begin{aligned} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}) = \sum _{{\mathbf {y}_{t-1}}} \sum _{{s_{t-1}}} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}) \cdot p({\mathbf {y}_{t-1}}, {s_{t-1}}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}). \end{aligned}$$
(2)

Transition Probability Model: We must define the transition probability \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}})\) to complete the recursion in Eq. (2). There are two possible cases: \({s_{t-1}}= 1\) and \({s_{t-1}}= 0\). Based on the definition of \({s_{t-1}}\), we obtain

$$\begin{aligned} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 1) = {\mathbb {1}}({\mathbf {y}_t}= {\mathbf {y}_{t-1}}) \cdot {\mathbb {1}}({s_t}= 1), \end{aligned}$$
(3)

where \({\mathbb {1}}(\cdot )\) denotes the indicator function. To compute \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\), we introduce a latent representation \({\mathbf {h}_t}= h({\mathbf {x}}, {\mathbf {y}_{t-1}})\) that jointly captures all of the necessary information from image \({\mathbf {x}}\) and previous segmentation \({\mathbf {y}_{t-1}}\). Intuitively, predicting whether the segmentation \({\mathbf {y}_t}\) is complete given \({\mathbf {x}}\) can be performed by examining whether \({\mathbf {y}_{t-1}}\) is “almost” complete. Therefore, the segmentation \({\mathbf {y}_t}\) and stopping indicator \({s_t}\) are conditionally independent given \({\mathbf {h}_t}\):

$$\begin{aligned} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)= p({\mathbf {y}_t}, {s_t}\vert {\mathbf {h}_t}) = p({\mathbf {y}_t}\vert {\mathbf {h}_t}) \cdot p({s_t}\vert {\mathbf {h}_t}). \end{aligned}$$
(4)

We model the function \(h({\mathbf {x}}, {\mathbf {y}_{t-1}})\) and distributions \(p({\mathbf {y}_t}\vert {\mathbf {h}_t})\) and \(p({s_t}\vert {\mathbf {h}_t})\) as stationary; they do not depend on the time step t.

Learning: We learn a representation of \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\) given a training dataset of example desired trajectories of segmentations. Specifically, we consider a training dataset \(\mathcal {D}\) of N images \(\{{\mathbf {x}^i}\}_{i=1}^N\), each of which has a corresponding sequence of segmentations \({\mathbf {y}_0^i}, \ldots , {\mathbf {y}^i_{T_i}}\) and of stopping indicators \({s_0^i}, \ldots , {s^i_{T_{i}}}\), where \({s_0^i}= \ldots = {s^i_{T_{i}-1}}= 0\) and \({s^i_{T_{i}}}= 1\). The parameter values to be determined are \({\varvec{\theta }}= \{{\varvec{\theta }_h}, {\varvec{\theta }_y}, {\varvec{\theta }_s}\}\) corresponding to \(h({\mathbf {x}}, {\mathbf {y}_{t-1}}; {\varvec{\theta }_h})\), \(p({\mathbf {y}_t}\vert {\mathbf {h}_t}; {\varvec{\theta }_y})\), and \(p({s_t}\vert {\mathbf {h}_t}; {\varvec{\theta }_s})\), respectively. We seek the parameter values that minimize the expected negative log-likelihood of the output segmentation and stopping indicator sequences given the image and initial conditions, i.e., \({\varvec{\theta }^{*}}= {\mathop {\mathop {\mathrm {argmin}}}\nolimits _{\varvec{\theta }}} \mathcal {L}({\varvec{\theta }})\),

$$\begin{aligned} \mathcal {L}({\varvec{\theta }})&= \mathbb {E}_{{\mathbf {x}}, {\mathbf {y}_0}, \ldots , {\mathbf {y}_T}, {s_0}, \ldots , {s_T}\sim \mathcal {D}}\Big [- \log p({\mathbf {y}_1}, \ldots , {\mathbf {y}_T}, {s_1}, \ldots , {s_T}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}; {\varvec{\theta }}) \Big ] \nonumber \\&= -\mathbb {E}\Big [ \sum _{t=1}^{T} \log p({\mathbf {y}_t}\vert h({\mathbf {x}}, {\mathbf {y}_{t-1}}; {\varvec{\theta }_h}) ; {\varvec{\theta }_y}) + \log p({s_t}\vert h({\mathbf {x}}, {\mathbf {y}_{t-1}}; {\varvec{\theta }_h}) ; {\varvec{\theta }_s}) \Big ]. \end{aligned}$$
(5)
Fig. 1. Iterative segmentation as an RNN. (a) Generative model. (b) The RNN uses the same augmented U-net at each step to predict the next segmentation and stopping indicator. (c) Architecture details (conditioning dropped for clarity).

Note that teacher forcing has led to decoupled time steps. The first and second terms in the likelihood above penalize differences between the predicted probabilities and the ground truth for the segmentations and the stopping indicators, respectively. In practice, we perform class rebalancing for both terms, and further supplement the segmentation loss by more strongly weighting voxels on the boundaries of the ground truth segmentation.
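
A sketch of the resulting per-step loss of Eq. (5) in PyTorch follows. The inverse-frequency class rebalancing and the 3×3×3 boundary detection are plausible assumptions on our part; the exact weighting scheme is not specified above.

```python
import torch
import torch.nn.functional as F

def step_loss(seg_logits, stop_logit, y_true, s_true, boundary_weight=2.0):
    """Negative log-likelihood for one (y_{t-1} -> y_t, s_t) training tuple.

    seg_logits : (1, L, D, H, W) unnormalized scores for y_t
    stop_logit : (1,) unnormalized score for s_t
    y_true     : (1, D, H, W) integer label map y_t
    s_true     : (1,) float in {0., 1.}, the stopping indicator s_t
    """
    # Class rebalancing: weight each label inversely to its frequency.
    num_labels = seg_logits.shape[1]
    counts = torch.bincount(y_true.flatten(), minlength=num_labels).float()
    class_w = counts.sum() / (num_labels * counts.clamp(min=1.0))

    per_voxel = F.cross_entropy(seg_logits, y_true, weight=class_w,
                                reduction='none')

    # Boundary emphasis: upweight voxels where the ground-truth label changes
    # within a 3x3x3 neighborhood (max-pool differs from min-pool).
    y = y_true.float().unsqueeze(1)
    boundary = (F.max_pool3d(y, 3, 1, 1) != -F.max_pool3d(-y, 3, 1, 1))
    weights = 1.0 + (boundary_weight - 1.0) * boundary.squeeze(1).float()

    seg_nll = (weights * per_voxel).mean()
    stop_nll = F.binary_cross_entropy_with_logits(stop_logit, s_true)
    return seg_nll + stop_nll
```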

Inference: Computing \(p({\mathbf {y}_T}, {s_T}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}= 0)\) via the recursion in Eq. (2) is intractable due to the summation over all possible segmentations \({\mathbf {y}_{t-1}}\). To approximate it, we follow the widely used practice of feeding the most likely segmentation \({\mathbf {y}_{t-1}^{*}}\) and stopping indicator \({s^{*}_{t-1}}\) into the subsequent computation:

$$\begin{aligned} p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}= 0 ; {\varvec{\theta }})&\approx p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}^{*}}, {s^{*}_{t-1}}; {\varvec{\theta }}),\nonumber \\ \quad \text {where } {\mathbf {y}_{t-1}^{*}}, {s^{*}_{t-1}}&= \mathop {\mathrm {argmax}}_{{\mathbf {y}_{t-1}}, \, {s_{t-1}}} p({\mathbf {y}_{t-1}}, {s_{t-1}}\vert {\mathbf {x}}, {\mathbf {y}_0}, {s_0}= 0 ; {\varvec{\theta }}). \end{aligned}$$
(6)

The segmentation is fully automatic given the initial seed. If the stopping indicator is predicted incorrectly, a user can manually override it by asking for more iterations or by choosing a segmentation from a previous step.
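
A minimal sketch of this greedy inference loop, assuming a `model` that returns the per-voxel label probabilities and the stopping probability (see the RNN below); the `max_steps` safeguard and the threshold are our own illustrative additions.

```python
import torch

@torch.no_grad()
def iterative_segment(model, x, y0, max_steps=30, stop_threshold=0.5):
    """Greedy approximation of Eq. (6): feed the argmax segmentation back in.

    x  : (1, 1, D, H, W) input image
    y0 : (1, D, H, W) integer label map built from the user-placed seed
    """
    y = y0
    for _ in range(max_steps):
        seg_probs, stop_prob = model(x, y)  # p(y_t | h_t), p(s_t = 1 | h_t)
        y = seg_probs.argmax(dim=1)         # per-voxel argmax (see Sect. 3)
        if stop_prob.item() > stop_threshold:
            break                           # s_t = 1: segmentation complete
    return y
```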

RNN: We implement our iterative segmentation model as an RNN (Fig. 1), which is formed by connecting identical copies of an augmented 3D U-net [17] trained to estimate \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\). Thus, parameters are shared both spatially and temporally. At each step, the U-net inputs the image and the most likely segmentation from the previous step. This respects the Markov property in Eq. (1), unlike if any hidden layers were connected between successive steps. If the stopping indicator \({s_t^{*}}= 1\), the segmentation propagation halts.

Our augmented U-net modeling \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\) has \(L+1\) input channels, containing the input image and a binary map for each of the L labels in the segmentation \({\mathbf {y}_{t-1}}\) (including the background). There are two outputs: the probability map for the segmentation \({\mathbf {y}_t}\) (at each voxel, representing the parameters of the categorical distribution over L labels), and the Bernoulli stopping parameter \(p({s_t}= 1 \vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\). Jointly predicting the segmentation and stopping indicator enables a smaller model compared to two separate networks.

The original U-net for image segmentation produces a final set of C learned feature maps, which undergo \(C \cdot L\) \(1 \times 1 \times 1\) convolutions and a softmax activation to give the output segmentation probabilities. We use these C learned feature maps as the latent joint representation \({\mathbf {h}_t}= h({\mathbf {x}}, {\mathbf {y}_{t-1}}; {\varvec{\theta }_h})\). The U-net parameters can therefore be split into two sets. The parameters for the final \(1 \times 1 \times 1\) convolutions are \({\varvec{\theta }_y}\) of \(p({\mathbf {y}_t}\vert {\mathbf {h}_t}; {\varvec{\theta }_y})\), and the remainder are \({\varvec{\theta }_h}\) of \(h({\mathbf {x}}, {\mathbf {y}_{t-1}}; {\varvec{\theta }_h})\). The probability \(p({s_t}= 1 \vert {\mathbf {h}_t}; {\varvec{\theta }_s})\) is computed by applying C additional \(3 \times 3 \times 3\) convolutions with parameters \({\varvec{\theta }_s}\) to the feature maps in \({\mathbf {h}_t}\), followed by a global average and sigmoid activation to yield a scalar in \((0,1)\).
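
As a sketch of these two heads in PyTorch, where `unet_body` stands in for the encoder-decoder producing \({\mathbf {h}_t}\); implementing the stopping head as a single convolution layer with C \(3 \times 3 \times 3\) kernels followed by global averaging is one plausible reading of the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AugmentedUNet(nn.Module):
    """Two output heads on top of the U-net's final C feature maps h_t."""

    def __init__(self, unet_body, C=24, L=8):
        super().__init__()
        self.L = L
        self.body = unet_body                                   # theta_h
        self.seg_head = nn.Conv3d(C, L, kernel_size=1)          # theta_y
        self.stop_head = nn.Conv3d(C, 1, kernel_size=3, padding=1)  # theta_s

    def forward(self, x, y_prev):
        # One binary map per label (background included), plus the image:
        # L + 1 input channels in total.
        onehot = F.one_hot(y_prev, self.L).permute(0, 4, 1, 2, 3).float()
        h = self.body(torch.cat([x, onehot], dim=1))            # h_t
        seg_probs = torch.softmax(self.seg_head(h), dim=1)      # p(y_t | h_t)
        stop_prob = torch.sigmoid(self.stop_head(h).mean(dim=(2, 3, 4)))
        return seg_probs, stop_prob.squeeze(1)                  # p(s_t=1 | h_t)
```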

Generating Segmentation Trajectories: Our training dataset of images and segmentation trajectories is derived from a collection of paired images and complete segmentations. Several acceptable trajectories exist for each pair, e.g., starting from different initial seeds. To sample this diversity, at the beginning of each epoch a random tuple \(({\mathbf {y}_{t-1}}, {\mathbf {y}_t}, {s_t})\) is generated for each image. All tuples follow the same growth principle that we want the network to learn.

As a concrete example, the trajectories used in our experiments are as follows. For the aorta, the segmentation grows from the seed along the vessel centerline, by a random distance to form \({\mathbf {y}_{t-1}}\) and an additional 10 pixels for \({\mathbf {y}_t}\). The seed is placed in the descending aorta, and the endpoint is at the valve where the aorta connects to a left or right ventricle. This seed could be automatically detected in the future, and the lack of contrast at the valve provides a challenging test case for our automatic stopping. For the left ventricle, we randomly place the seed in the center region of the chamber, perform a random number of dilations to form \({\mathbf {y}_{t-1}}\), and 3 more dilations to form \({\mathbf {y}_t}\).
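
A sketch of the LV variant using scipy morphology; the erosion depth defining the “center region” and the range of random dilations are illustrative assumptions.

```python
import numpy as np
from scipy import ndimage

def lv_training_tuple(full_seg, rng):
    """Sample one (y_{t-1}, y_t, s_t) tuple for the left ventricle.

    full_seg : 3D boolean array, the complete gold-standard LV segmentation
    rng      : np.random.Generator
    """
    # Random seed voxel deep inside the chamber (the "center region").
    center = ndimage.binary_erosion(full_seg, iterations=5)
    zs, ys, xs = np.nonzero(center)
    i = rng.integers(len(zs))
    seed = np.zeros_like(full_seg)
    seed[zs[i], ys[i], xs[i]] = True
    seed = ndimage.binary_dilation(seed, iterations=2)  # small initial sphere

    # Grow by a random number of dilations for y_{t-1} and 3 more for y_t,
    # always clipped to the gold-standard mask.
    y_prev = ndimage.binary_dilation(seed, iterations=rng.integers(1, 12))
    y_prev &= full_seg
    y_next = ndimage.binary_dilation(y_prev, iterations=3) & full_seg
    s_next = float(np.array_equal(y_next, full_seg))    # stop when complete
    return y_prev, y_next, s_next
```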

Data Augmentation: Data augmentation is essential to prevent overfitting on a small training dataset. We mimic the diversity of heart shapes and sizes, global intensity changes caused by inhomogeneity artifacts, and noise induced by elevated heart rates or arrhythmias. We apply random rigid and nonrigid transformations, random constant intensity shifts, and random additive Gaussian noise. We also investigate including random left-right (L-R) and anterior-posterior (A-P) flips, to better handle dextrocardia or other cardiac malpositions, since in these cases the left ventricle may lie on the right side of the body.
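
A sketch of the intensity and flip components in numpy; the parameter magnitudes and axis conventions are illustrative, and the rigid and nonrigid spatial transforms are omitted for brevity.

```python
import numpy as np

def augment(image, seg, rng, allow_flips=True):
    """Random intensity shift, additive Gaussian noise, and L-R / A-P flips."""
    image = image + rng.uniform(-0.1, 0.1)               # global intensity shift
    image = image + rng.normal(0.0, 0.05, image.shape)   # additive Gaussian noise
    if allow_flips:
        for axis in (0, 2):  # assumed A-P and L-R axes; depends on orientation
            if rng.random() < 0.5:
                image = np.flip(image, axis)
                seg = np.flip(seg, axis)
    return image, seg
```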

If the augmented U-net for \(p({\mathbf {y}_t}, {s_t}\vert {\mathbf {x}}, {\mathbf {y}_{t-1}}, {s_{t-1}}= 0)\) is trained solely using error-free segmentations \({\mathbf {y}_{t-1}}\), then it may not operate well on its own imperfect intermediate results at test time. We increase robustness by performing additional data augmentation on the input segmentations \({\mathbf {y}_{t-1}}\). We corrupt these segmentations by applying random nonrigid deformations, and by inserting random blob-like structures that vary in number, location and size and are attached to the segmentation foreground or free-floating. Since the target segmentation \({\mathbf {y}_t}\) remains unchanged, the model learns to correct mistakes in its input.
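
A sketch of the blob insertion; spherical blobs are shown for simplicity (the blobs described above are irregular), and the random nonrigid deformation of the input segmentation is omitted.

```python
import numpy as np

def corrupt_segmentation(y_prev, rng, max_blobs=3, max_radius=6):
    """Insert random blob-like structures into y_{t-1} (binary case shown).

    The target y_t is left unchanged, so the network learns to undo these.
    Random centers make blobs either attached to the foreground or free-floating.
    """
    y = y_prev.copy()
    zz, yy, xx = np.ogrid[:y.shape[0], :y.shape[1], :y.shape[2]]
    for _ in range(rng.integers(0, max_blobs + 1)):
        c = [rng.integers(s) for s in y.shape]
        r = rng.integers(2, max_radius + 1)
        y |= ((zz - c[0]) ** 2 + (yy - c[1]) ** 2 + (xx - c[2]) ** 2) <= r ** 2
    return y
```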

3 Experimental Validation

We evaluate our iterative segmentation and tailored direct segmentation methods, focusing on segmenting the aorta and left ventricle (LV) of CHD patients.

Data: We use the HVSMR dataset of 20 MRI scans from patients with a variety of congenital heart defects [18]. Each high-resolution (\(\approx \)0.9 mm\(^3\)) 3D image was acquired on a 1.5 T scanner (Philips Achieva), without contrast agent and using a free-breathing SSFP sequence with ECG and respiratory navigator gating. The HVSMR dataset includes blood pool and myocardium segmentations only, so a trained rater manually separated all of the heart chambers and great vessels. The 20 images were categorized after visually assessing any gross morphological malformations: 4/20 severe (prior major reconstructive surgery, single ventricle, dextrocardia), 5/20 moderate (DORV, ventricular septal defects, abnormal chamber shapes), and 11/20 mild (atrial septal defects, stenosis, etc.). The dataset was randomly split into 4 folds for cross-validation (15 training, 5 testing), with the mild, moderate and severe cases distributed as evenly as possible across the folds. Input images were resized to \(\approx \)128 \(\times \) 180 \(\times \) 144 voxels.

Experiments: In our tests, binary segmentation of each structure outperformed co-segmenting all of the heart chambers and vessels. We trained several models aimed at segmenting the aorta and left ventricle of CHD patients. DIR uses a single U-net to perform direct binary segmentation. DIR-DIST includes the Euclidean distance to the initial seed as an additional input channel. ITER (stop) is iterative segmentation using our RNN with automatic stopping, and ITER (max) simulates a user by choosing the segmentation with the best Dice coefficient after 30 iterations of our RNN. Finally, ITER-SEG-ABL is an ablation study with no data augmentation on the input segmentations. We tuned the architectural parameters for each experiment separately, which nevertheless resulted in similar networks. All U-nets had 3 levels, 24 feature maps at the first level, and \(\approx \)870,000 parameters. The best network for direct segmentation of the aorta used \(2 \times 2 \times 2\) max pooling (receptive field = \(40^3\)), while all others used \(3 \times 3 \times 3\) max pooling (receptive field = \(68^3\)). For training, optimization using Adadelta ran for 2000 epochs with a batch size of 1. For iterative segmentation, the \(\mathrm {argmax}\) in Eq. (6) is computed per voxel, by assuming that the segmentation of each voxel is conditionally independent of all other voxels given \({\mathbf {h}_t}\). Segmentations were post-processed to keep only the largest island or the island containing the initial seed, for experiments in which this improved overall accuracy. Aorta segmentations were not penalized for descending aortas longer than in the gold standard.
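
For reference, a sketch of the island post-processing using scipy connected components; both variants mentioned above are shown.

```python
import numpy as np
from scipy import ndimage

def keep_seed_island(binary_seg, seed):
    """Keep only the connected component containing the initial seed point."""
    labels, _ = ndimage.label(binary_seg)
    seed_label = labels[tuple(seed)]
    return binary_seg if seed_label == 0 else labels == seed_label

def keep_largest_island(binary_seg):
    """Keep only the largest connected component."""
    labels, n = ndimage.label(binary_seg)
    if n == 0:
        return binary_seg
    sizes = ndimage.sum(binary_seg, labels, index=range(1, n + 1))
    return labels == (int(np.argmax(sizes)) + 1)
```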

Fig. 2. Aorta (AO) and LV segmentation validation. DIR-DIST is the best direct segmentation method, but iterative segmentation generalizes better to severe subjects. Top: Dice coefficients for all methods. Bottom: Results for all 20 subjects, sorted by DIR-DIST score and with severe subjects highlighted in green.

Results: Figures 2 and 3 report the results. There was no notable difference in accuracy between the mild and moderate groups. DIR-DIST was the best direct segmentation method, demonstrating the advantage of leveraging user interaction. For all methods, incorporating L-R and A-P flips in the data augmentation improved performance for severe subjects. Iterative segmentation stopped automatically after \(18 \pm 3\) steps for both the aorta and the LV, requiring \(\approx \)15 s. The potential benefits of our iterative segmentation approach are demonstrated by the performance of ITER (max), which shows improvement for all of the severe cases while maintaining accuracy for the others. The stopping prediction is not perfect at test time: the number of iterations separating the automatic stopping point from the best segmentation in a sequence was \(0.8 \pm 1.0\) for the aorta and \(3.0 \pm 2.5\) for the LV. The sole aorta containing a stent was poorly segmented by all methods (Fig. 3e). The stent caused a strong inhomogeneity artifact that the iterative segmentation could not grow past, and the stopping criterion was never triggered.

Fig. 3. Representative aorta and LV segmentations in held-out subjects with severe CHD. Arrows illustrate both the benefits and failure cases of iterative segmentation with automatic stopping, where it (a) successfully segments a difficult case, (b) stops too late, (c) correctly stops near a valve, (d) avoids growing through a septal defect, (e) cannot grow through a dark region caused by a stent.

4 Conclusions

We presented an iterative segmentation model and its RNN implementation. We showed that for whole heart segmentation, the iterative approach was more robust to the cardiac malformations of severe CHD. Future work will investigate the potential general applicability of iterative segmentation when one is restricted to a small training dataset despite wide anatomical variability.