
Machine learning has emerged as an effective approach for developing predictive models for high-throughput screening of materials1,2,3,4,5,6,7,8. For example, machine-learned models for formation energy prediction can construct a convex hull for a rapid assessment of the thermodynamic stability of compounds, with reasonable accuracy, at a fraction of the computational cost and time needed for density functional theory (DFT)-calculated convex hulls9. In materials research, a machine learning model can be characterized by two aspects: the representation of the material as a readable entity (or input) to the learning algorithm, and the learning algorithm itself. Machine learning approaches have investigated a variety of representations, from ones as simple as a pool of physicochemical attributes (e.g., atomic number, cohesive energy, band gap, and heat of melting) and composition vectors10,11,12,13,14,15,16,17 up to more advanced graph representations of the composition and structure of crystal compounds19,20,21,22,23,24,25. The use of image representations for machine learning, however, has been less explored in the materials research community. Image representation can be especially useful because of the significant advancements that have been made in pattern recognition (or representation learning) of visual images in the field of computer vision (a field of computer science that deals with processing and understanding visual data like images or videos). These advancements are largely due to the evolution towards more sophisticated architectures of convolutional neural networks (e.g., Residual Neural Network (ResNet)26, EfficientNet27, U-Net28), which has enabled the adoption of increasingly deeper networks. Inspired by this untapped opportunity for materials representation learning, we develop a sparse voxel image representation of crystalline materials that is input into a very deep convolutional neural network (CNN) with a sophisticated architecture inspired by ResNet.

We use the formation energy prediction of crystalline compounds as a platform for demonstrating the performance of our deep-learning model on voxel images of crystals. Formation energy is an ideal platform because large databases of DFT-calculated formation energies are available (e.g., Materials Project29 and AFLOW (Automatic Flow)30), which provide the large amount of data needed for training our deep CNN. Additionally, there are several available machine learning approaches for formation energy prediction with which we compare the performance of our model. We show that our model’s formation energy predictive performance is comparable to the state-of-the-art machine learning models’ prediction. We present a thorough comparison of 3115 binary convex hulls constructed from our model’s formation energy against DFT-calculated binary convex hulls in the Materials Project database. By introducing multiple error metrics for assessing binary convex hulls, we showcase how the error in the formation energy prediction is projected into the performance of a predicted convex hull.

Among machine learning methods for formation energy prediction of crystal compounds, graph neural networks have shown promising performance because the graph data structure can efficiently capture the physical, compositional, and structural information of crystal compounds19,20,21,22,23,24,25.

In our approach, we adopt the architecture of residual blocks to construct a 15-layer CNN with 7 skip connections. The overall architecture, as depicted in Fig. 1, consists of a deep CNN followed by a fully connected neural network for the prediction of formation energy using sparse voxel images of crystals. The deep CNN part of the architecture is employed for feature learning of voxel crystal images. These learned features are then flattened and passed as input to the fully connected neural network, which performs the final prediction of the formation energy. In our network design, we deliberately delay the introduction of pooling layers in our CNN. The first pooling layer is introduced only after the fifth convolutional kernel, with subsequent pooling layers added after the eleventh and fifteenth kernels, respectively. A detailed description of our CNN architecture can be found in Methods. In the context of materials representation learning, the use of skip connections in our CNN allows for the bypassing of local atomic features discovered in the shallower layers, while progressively learning more global features of crystal compounds across the layers of the deep network. This hierarchical learning approach facilitates the extraction of relevant abstractions, enabling the model to capture both local and global features within the crystal structures.

Our CNN, inspired by the ResNet architecture described in ref. 26, incorporates slight modifications to better suit our specific task. In contrast to the original design, we choose not to adopt the batch normalization technique in our residual blocks. This decision is based on the observation that batch normalization hampers the training of our CNN, likely due to the intrinsic differences between sparse crystal images and natural images (such as those in ImageNet41). Consequently, the batch normalization process may not yield the intended benefits for our crystal image representation. Furthermore, we adjust the way in which we handle the number of channels within our network. Instead of doubling the number of channels after each convolution layer, as outlined in the original ResNet design, we increase the number of channels, after each pooling, by concatenating the side skip connections with the output of the convolution layer. This alternative approach allows for a more effective utilization of information from both the skip connections and the convolutional layers, promoting better feature representation within our network. By tailoring the ResNet-inspired architecture to the characteristics of our crystal images, we optimize the training process and enhance the performance of our CNN for the specific task of crystal compound formation energy prediction.

Data Sets

We obtained a data set of 139,367 crystal structures along with their corresponding DFT-calculated formation energies (the target variables) from Materials Project (v2021.05.13)29. From this, 15,354 structures are excluded because they either require a high resolution or a large image (more details in Methods). To train our model, we split the data into train (60%), validation (20%), and test (20%) sets. During the data pre-processing stage, we removed 9175 crystal structures from the train set that either contain two atoms occupying the same voxel or have a unit cell that does not fit in the 17-Å cubic box, as described in detail in Methods. During training, we employ data augmentation by randomly rotating each crystal image before feeding it into the model at each epoch (see Supplementary Fig. S1). This technique helps alleviate overfitting (see Supplementary Fig. S4) and enhances the predictive performance of our model. Data augmentation is particularly beneficial as it effectively increases the size of the train data and implicitly enforces the rotation-invariance of crystal compounds with respect to their formation energy, as explained further below. To monitor the training process and prevent overfitting, we use predictions on the validation data. Once the model is trained, we evaluate its overall performance using the test data, as outlined below. In the Discussion section, we delve into the significance of data augmentation and skip connections in our CNN architecture, highlighting their role in improving the model’s performance.
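As an illustration of the 60/20/20 partition described above, the following minimal sketch shuffles sample indices and slices them into three subsets. The function name `split_indices`, the fixed seed, and the use of a simple random shuffle are assumptions made for this example; the text does not specify how the split was drawn.

```python
import random


def split_indices(n_samples, fractions=(0.6, 0.2, 0.2), seed=0):
    """Shuffle sample indices and split them into train/validation/test lists.

    The random shuffle and the seed are illustrative assumptions; only the
    60/20/20 proportions come from the text.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_train = round(fractions[0] * n_samples)
    n_val = round(fractions[1] * n_samples)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # remainder goes to the test set
    return train, val, test


train, val, test = split_indices(1000)
print(len(train), len(val), len(test))
```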

Formation Energy Prediction Assessment

In this section, we examine the performance of our model’s prediction. As detailed in Methods, we employ an ensemble averaging technique for predicting the formation energy. Figure 2a shows the parity plot of the formation energy prediction of our model against the DFT-calculated formation energies on both the train and test sets. The results indicate an MAE of 0.042 eV per atom and 0.046 eV per atom on the train and test sets, respectively. Over 89% of the samples in the test set exhibit absolute errors below 0.1 eV per atom, and only about 2% of the samples have absolute errors exceeding 0.2 eV per atom (see Supplementary Fig. S2b). The formation energy prediction error (i.e., predicted formation energy − DFT formation energy) shows a slightly positive skew normal distribution with a median and mean value of 0.003 eV per atom and −0.003 eV per atom, respectively, on the test set (see Supplementary Fig. S2b). As shown in Fig. 2b, c, our model tends to exhibit higher errors for crystal compounds with more positive and larger formation energies. This trend has also been observed in other studies16,42. To exemplify this trend, we analyze four equally populated subsets of our test set sorted by the formation energy, with respective formation energy ranges of (−4.47, −2.39), (−2.39, −1.47), (−1.47, −0.46), and (−0.46, 5.33) eV per atom and calculated MAEs of 0.037, 0.039, 0.046, and 0.064 eV per atom, respectively. The relatively diminished prediction performance observed for larger, positive-value ranges of formation energy can be attributed to an inherent bias in the existing dataset. The data available in the Materials Project predominantly comprises chemically stable structures characterized by negative formation energies. In contrast, chemically unstable crystal structures with positive formation energies remain a minority within this dataset. Notably, less than 10% of all samples possess positive formation energy (see Supplementary Fig. S2a).
Pandey et al.43 have elucidated how this disparity in data distribution impacts the model’s predictive capabilities.

Fig. 2: Formation energy prediction evaluation.

a The parity plot for samples in the train and test sets. The MAE of formation energy prediction for the test and train data is reported in the legend. b, c Distribution of the prediction error of test data over different ranges of formation energy. b Box and whisker representation of prediction error (i.e., predicted Ef - DFT Ef) for different intervals of DFT formation energy. The left side, middle line, and right side of each box show respectively the first quartile, median, and third quartile of the error. The whisker line shows the minimum and maximum of the error. c The scatter plot of samples in the test set showing the DFT formation energy versus prediction error.

We conducted a comparative analysis of our model's predictive performance with state-of-the-art machine learning models, including ElemNet16 and Roost (Representation Learning from Stoichiometry)17 as the best models based on compositional features, and ALIGNN23 and CGCNN as the best models based on graph representations of crystal structures. Traditionally, skip connections are recognized for their role in alleviating optimization challenges by producing smoother loss functions, facilitating easier training52. However, our work sheds light on an additional aspect of skip connections beyond their optimization benefits. We demonstrate that skip connections serve as a mechanism to capture the essential physicochemical information at different levels. By allowing the outputs of different layers (both shallow and deep) to bypass through identity mapping, skip connections enable the network to leverage local atomic fingerprints from shallower layers while simultaneously learning abstract, generalized features from deeper layers. In this way, skip connections facilitate the integration of both local and global information, leading to improved performance in formation energy prediction.


Data collection and voxel image preparation

We gather crystal structure information in Crystallographic Information File (CIF) format and the corresponding DFT-calculated formation energies from the Materials Project database (v2021.05.13)29. To extract the structural information, we utilize the Atomic Simulation Environment (ASE) package55. Our in-house Python code is then employed to generate sparse voxel images of the crystals. In the voxelization process, we repeat the crystal unit cell (cubic or non-cubic) in space to fill a cubic box with an edge size of 17 Å. We eliminate a crystal structure if its unit cell does not fit in the cubic box. The box is then voxelized using a 32 × 32 × 32 grid, resulting in images with dimensions of 32 × 32 × 32 voxels. To ensure that each voxel contains at most one atom, we set the minimum interatomic distance to be greater than the diagonal of a voxel, dv, calculated as \({d}_{v}=(17/32)\times \sqrt{3}=0.92\) Å. Consequently, crystal structures with minimum interatomic distances smaller than 0.92 Å are filtered out. The 3D sparse voxel images of crystals are color-coded using three channels, similar to an RGB image. These channels represent the normalized atomic number, group number, and period number. For lanthanides and actinides, we assign a group number of 3.5. During training, to introduce variability and enhance generalization, we apply a random rotation to each crystal image at each epoch. Rather than applying a direct rotation to the unit cell and subsequently executing the computationally intensive task of filling the 17 Å box - a method which becomes intractably repetitive - we initially construct a larger ‘encompassing’ box with an edge equal to the diagonal of the 17-Å cubic box. During the data pre-processing stage, we fill the larger box by replicating the crystal unit cell in all directions only once. 
Consequently, whenever an instance of a crystal structure input is requested, either for training or prediction, we apply a random rigid-body rotation to the larger box, while the 17-Å box remains unchanged and consistently populated after each rotation. Thereafter, we perform the voxelization of the 17-Å box to generate the final sparse voxel images. Supplementary Fig. S1 visually details the rotation methodology.
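The voxelization and rotation steps described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the in-house code: the function names `random_rotation` and `voxelize`, the QR-based sampling of the rotation, and the convention that atom coordinates are centred on the origin are all choices made for this example. The per-atom features (normalized atomic number, group, and period) are taken as precomputed inputs because the normalization constants are not given in the text.

```python
import numpy as np

BOX = 17.0   # edge of the cubic box in angstrom (from the text)
GRID = 32    # voxels per edge (from the text)
VOXEL_DIAGONAL = BOX / GRID * np.sqrt(3)  # about 0.92 angstrom


def random_rotation(rng):
    """Draw a random proper rotation matrix via QR decomposition (assumed scheme)."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix column signs so the draw is unbiased
    if np.linalg.det(q) < 0:      # enforce det = +1 (no reflection)
        q[:, 0] = -q[:, 0]
    return q


def voxelize(positions, features, rotation=None):
    """Build a sparse (32, 32, 32, 3) image from atoms in the encompassing box.

    positions: (N, 3) Cartesian coordinates in angstrom, centred on the origin.
    features:  (N, 3) normalized atomic number, group number, period number.
    rotation:  optional 3x3 rotation applied about the origin before cropping.
    """
    positions = np.asarray(positions, dtype=float)
    features = np.asarray(features, dtype=float)
    if rotation is not None:
        positions = positions @ rotation.T
    # Keep only the atoms that fall inside the fixed 17-angstrom box.
    inside = np.all(np.abs(positions) < BOX / 2, axis=1)
    idx = np.floor((positions[inside] + BOX / 2) / (BOX / GRID)).astype(int)
    image = np.zeros((GRID, GRID, GRID, 3))
    image[idx[:, 0], idx[:, 1], idx[:, 2]] = features[inside]
    return image


rng = np.random.default_rng(0)
atoms = [[0.0, 0.0, 0.0], [20.0, 0.0, 0.0]]       # second atom lies outside the box
feats = [[0.1, 0.2, 0.3], [0.5, 0.5, 0.5]]        # placeholder feature values
image = voxelize(atoms, feats, rotation=random_rotation(rng))
print(image.shape, round(float(VOXEL_DIAGONAL), 2))
```

Because only the atoms are rotated and the crop window stays fixed, the 17-Å box remains fully populated as long as the encompassing box has an edge at least equal to the diagonal of the 17-Å cube, which is the condition stated in the text.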

Convolutional neural network

We develop a 15-convolutional-layer network consisting of 7 residual blocks and 3 average pooling layers, followed by a fully connected neural network (see Supplementary Fig. S11). Each residual block consists of two convolutional layers, each followed by a rectified linear unit (ReLU) activation layer and a skip connection that connects the beginning of the block to its end. In each convolutional layer, we use a kernel of size 3 and padding of type SAME with stride 1 to ensure that the filter is applied to all the voxels of the input. To merge a skip connection (i.e., side stream) with the mainstream coming from the convolutional layer, we either use addition or concatenation. We use the concatenation of outputs only before a pooling layer in order to double the number of channels while reducing the image size during pooling. The addition of outputs is used elsewhere as the method of merging in the residual blocks.

The deep convolutional network consists of three distinct segments, each containing a different image size and ending with a pooling layer. The first segment consists of a single convolutional layer, followed by an activation layer. In this layer, we increase the number of channels from 3 to 32. This single layer is followed by two residual blocks, each consisting of two convolutional and activation layers, outputting 32 channels. We utilize concatenation to combine the outputs of the mainstream and skip connection, rendering the number of channels of the output of this segment equal to 64. This segment ends with an average pooling layer, reducing the image size by half (16 × 16 × 16). The second segment consists of three residual blocks, followed by an average pooling. The images passing through this segment have 64 channels, and at the end of the segment, their size is reduced by half (8 × 8 × 8) and their channels are doubled (128). The last segment consists of two residual blocks and an average pooling layer, but in this case, the last block uses addition instead of concatenation, keeping the channels as 128 and reducing the size to 4 × 4 × 4. A detailed schematic of the network is shown in Supplementary Fig. S11.
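To make the bookkeeping of image sizes and channel counts easier to follow, the sketch below traces the tensor shape through the three segments described above. It does not build the network; it only tracks the edge length and channel count. The function name `trace_shapes` and the tuple layout are assumptions made for this illustration.

```python
def trace_shapes(size=32, in_channels=3):
    """Return (stage, edge length, channels) after each stage of the described CNN."""
    stages = [("input", size, in_channels)]
    channels = 32  # the first convolution maps 3 -> 32 channels
    stages.append(("conv", size, channels))
    # (number of residual blocks, whether the last block merges by concatenation)
    segments = [(2, True), (3, True), (2, False)]
    for n_blocks, concat_last in segments:
        for b in range(n_blocks):
            is_last = b == n_blocks - 1
            if is_last and concat_last:
                channels *= 2  # concatenating skip and main streams doubles channels
            stages.append(("residual", size, channels))
        size //= 2  # average pooling halves each edge
        stages.append(("pool", size, channels))
    return stages


for stage in trace_shapes():
    print(stage)
```

The final stage has an edge length of 4 and 128 channels, so flattening gives 4 × 4 × 4 × 128 = 8192 features, and the seven residual blocks with two convolutions each plus the initial convolution give the 15 convolutional layers stated in the text.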

The output of the last pooling layer is flattened to a vector of size 8192 (4 × 4 × 4 × 128) and is connected to a fully connected network with a node architecture of 16-16-1 with linear activation functions. The Keras package56 is used to build and train this network. The 3D images of the train set are randomly rotated in 3D space and input to the network for 500 epochs in batches of size 32. The mean squared error (MSE) is used as the loss function. To train the network, we use the Adam optimizer with a learning rate of 0.001, exponential decay rates of 0.9 and 0.999 for the first and second moment estimates, respectively, and a numerical-stability constant (ϵ) of 1e-07.

Rotational ensemble averaging

Once the model is trained, we employ an ensemble averaging method for predicting the formation energy. When a crystal sample is input into the trained model, an ensemble of 50 randomly rotated instances of the sample is generated and the formation energy prediction is averaged over the ensemble. The ensemble averaging method improves the prediction accuracy and robustness of our model, as detailed in the Results section (Fig. 3) and Supplementary Figs. S6 and S7.
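The averaging step can be sketched as below. The callables `predict` and `rotate` are hypothetical stand-ins for the trained model and the random-rotation routine, and returning the spread alongside the mean is an addition made for this illustration; only the ensemble size of 50 comes from the text.

```python
import random
import statistics


def ensemble_predict(predict, rotate, sample, n_rotations=50, seed=0):
    """Average predictions over randomly rotated copies of one sample.

    predict: callable mapping an image to a formation energy (hypothetical wrapper).
    rotate:  callable (sample, rng) -> randomly rotated image (hypothetical helper).
    Returns the ensemble mean and the population standard deviation.
    """
    rng = random.Random(seed)
    predictions = [predict(rotate(sample, rng)) for _ in range(n_rotations)]
    return statistics.fmean(predictions), statistics.pstdev(predictions)


# Toy stand-ins: the "image" is a number and the orientation adds a small offset.
def toy_rotate(sample, rng):
    return sample + rng.uniform(-0.05, 0.05)


def toy_predict(image):
    return image


mean, spread = ensemble_predict(toy_predict, toy_rotate, sample=-1.0)
print(mean, spread)
```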

Error metrics

The evaluation of the formation energy prediction and the constructed convex hull is performed using the following error metrics:

Formation Energy Mean Absolute Error (MAE): The MAE is calculated using the formula:

$$\mathrm{MAE}=\frac{1}{n}\sum_{i=1}^{n}\left|y_{i}-\hat{y}_{i}\right|$$

where \(y_{i}\) represents the true formation energy of sample i (DFT-calculated formation energy obtained from the Materials Project database), \(\hat{y}_{i}\) corresponds to the model’s prediction of the formation energy for sample i, and the sum runs over a total of n samples. When computing the MAE for a binary convex hull prediction, only crystal compounds (or samples) from that specific binary system are included.
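A direct transcription of the MAE formula is shown below as a minimal sketch; the function name and the example values are illustrative only and are not taken from the data set.

```python
def mean_absolute_error(y_true, y_pred):
    """Mean absolute error between DFT and predicted formation energies (eV per atom)."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Illustrative values only.
dft = [-1.0, -2.0, 0.5]
predicted = [-1.1, -1.8, 0.5]
print(mean_absolute_error(dft, predicted))
```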

Depth error for Convex Hull: The depth error for the convex hull measures the difference in the confined area between the predicted and true convex hulls, and is defined as:

$$\text{Depth error}=\frac{A_{\mathrm{predicted}}-A_{\mathrm{true}}}{A_{\mathrm{true}}}$$

where \(A_{\mathrm{predicted}}\) and \(A_{\mathrm{true}}\) represent the areas enclosed by the predicted and true (or DFT-calculated) convex hulls, respectively.
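One way to evaluate the depth error for a binary system is sketched below. It assumes that each compound is given as a pair (composition fraction x, formation energy per atom), that the two elemental end members sit at zero formation energy, and that the enclosed area is the area between the lower convex hull and the zero-energy line. These conventions and the function names are assumptions made for this illustration; the text does not spell them out.

```python
def lower_hull(points):
    """Lower convex hull of (x, energy) points, with end members (0, 0) and (1, 0)."""
    interior = [(x, e) for x, e in points if 0.0 < x < 1.0 and e < 0.0]
    pts = sorted(set([(0.0, 0.0), (1.0, 0.0)] + interior))
    hull = []
    for p in pts:
        while len(hull) >= 2:
            (x1, y1), (x2, y2) = hull[-2], hull[-1]
            cross = (x2 - x1) * (p[1] - y1) - (y2 - y1) * (p[0] - x1)
            if cross <= 0:   # hull[-1] lies on or above the segment hull[-2] -> p
                hull.pop()
            else:
                break
        hull.append(p)
    return hull


def hull_area(hull):
    """Area between the hull and the zero-energy line (trapezoid rule)."""
    return sum(-(y1 + y2) / 2 * (x2 - x1) for (x1, y1), (x2, y2) in zip(hull, hull[1:]))


def depth_error(predicted_points, true_points):
    """Relative difference between the predicted and true enclosed areas."""
    a_true = hull_area(lower_hull(true_points))
    if a_true == 0:
        raise ValueError("true hull encloses no area")
    a_pred = hull_area(lower_hull(predicted_points))
    return (a_pred - a_true) / a_true


# Illustrative values only: one compound at x = 0.5, predicted slightly too stable.
print(depth_error([(0.5, -1.2)], [(0.5, -1.0)]))
```

Under this convention, a positive depth error means the predicted hull is deeper than the DFT hull, and a negative value means it is shallower.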

Accuracy of Convex Hull Prediction: The accuracy of the convex hull prediction is calculated as the percentage of correctly predicted crystal samples on the hull with respect to the crystal samples on the DFT-calculated hull. In other words, the hull accuracy measures the percentage of samples on the DFT-calculated hull that our model also predicts to be on the hull. Accordingly, if our model mistakenly predicts a crystal sample to be on the hull while the DFT-calculated sample is above the hull, the hull accuracy measure will not be affected (e.g., see Fig. 4a).
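The hull accuracy defined above can be computed with simple set operations, as in the following sketch. The function name and the compound labels are hypothetical, and the elemental end members are left out of both sets as an assumption of this example.

```python
def hull_accuracy(true_on_hull, predicted_on_hull):
    """Percentage of DFT on-hull samples that the model also places on the hull."""
    true_set = set(true_on_hull)
    if not true_set:
        raise ValueError("the DFT hull contains no samples")
    matched = true_set & set(predicted_on_hull)
    return 100.0 * len(matched) / len(true_set)


# Hypothetical labels: "AB3" is a false positive and does not change the score.
dft_hull = {"A2B", "AB"}
model_hull = {"AB", "AB3"}
print(hull_accuracy(dft_hull, model_hull))
```

Because the denominator counts only the DFT on-hull samples, the measure behaves like a recall: extra samples wrongly placed on the predicted hull leave it unchanged, as the text notes.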