Composition-Based Property Prediction of Perovskites with CrabNet
This notebook demonstrates how to use CrabNet to predict the bandgap of perovskites using data from the Perovskite Database in NOMAD for training. The method can be extended to predict other properties of perovskites and can also be combined with other methodologies available in the community for this task. To explore more of these methodologies, we recommend taking a look at MatBench, a benchmarking suite for these tasks in the materials informatics community.
Installations¶
We will start by running a couple of pip installers. Skip this part if you have the libraries installed in your environment.
! pip install torch
! pip install crabnet
! pip install pandas
Retrieve Data using NOMAD API¶
We will now fetch perovskite solar cell data from the NOMAD API. If you already have the data (the file perovskite_bandgap_devices.csv in the data folder), you can skip this part and continue with Load data. Note that querying the database through the API may take a while.
from time import monotonic
import jmespath
import requests
base_url = 'https://nomad-lab.eu/prod/v1/api/v1'
bandgaps = []
reduced_formulas = []
descriptive_formulas = []
page_after_value = None
def extract_values(entry):
    # Append the bandgap and both formula variants of one entry to the lists
    bandgaps.append(
        jmespath.search(
            'results.properties.electronic.band_structure_electronic[0].band_gap[0].value',
            entry,
        )
    )
    reduced_formulas.append(
        jmespath.search('results.material.chemical_formula_reduced', entry)
    )
    descriptive_formulas.append(
        jmespath.search('results.material.chemical_formula_descriptive', entry)
    )
start = monotonic()
while True:
    response = requests.post(
        f'{base_url}/entries/query',
        json={
            'owner': 'visible',
            'query': {
                'and': [
                    # {"results.material.elements:all": ["Sn"]},
                    {'sections:all': ['nomad.datamodel.results.SolarCell']}
                ]
            },
            'pagination': {'page_size': 1000, 'page_after_value': page_after_value},
        },
    )
    response_code = response.status_code
    data = response.json()
    pagination = data['pagination']
    if page_after_value is None:
        print(f'Total number of entries: {pagination["total"]}')
    print(response_code)
    # The cursor for the next page; missing on the last page
    page_after_value = data['pagination'].get('next_page_after_value')
    for entry in data['data']:
        extract_values(entry)
    if not page_after_value:
        break
end = monotonic()
print(f'Query took {end - start:.2f} seconds')
Total number of entries: 43108
200 (printed once per page request)
Query took 1230.80 seconds
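The loop above follows a cursor-style pagination pattern: each response carries a next_page_after_value that is sent back with the next request, and the loop stops when that value is absent. The following is a minimal sketch of the same pattern without any network access. The names iterate_pages and fake_fetch, and the in-memory entries, are hypothetical stand-ins for illustration and are not part of the NOMAD API.

```python
from typing import Callable, Iterator, Optional


def iterate_pages(fetch_page: Callable[[Optional[str]], dict]) -> Iterator[dict]:
    """Yield entries from a cursor-paginated source until no cursor is returned."""
    cursor = None
    while True:
        page = fetch_page(cursor)
        yield from page['data']
        cursor = page['pagination'].get('next_page_after_value')
        if not cursor:
            break


# Hypothetical in-memory stand-in for the API: 5 entries served in pages of 2
_ENTRIES = [{'id': i} for i in range(5)]


def fake_fetch(cursor: Optional[str]) -> dict:
    start = int(cursor) if cursor else 0
    end = start + 2
    pagination = {'total': len(_ENTRIES)}
    if end < len(_ENTRIES):
        # Only non-final pages advertise a cursor for the next request
        pagination['next_page_after_value'] = str(end)
    return {'data': _ENTRIES[start:end], 'pagination': pagination}


fetched = list(iterate_pages(fake_fetch))
print(len(fetched))
```

Separating the paging logic from the request itself makes it easy to test the stop condition before running the long query against the real service.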
You can verify that the data was fetched correctly by checking the lengths of the lists, e.g., len(bandgaps), which should match the total number of entries reported above.
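One possible way to do this check is sketched below. The helper check_equal_lengths and the toy lists are assumptions made for illustration; in the notebook you would pass the real bandgaps, reduced_formulas, and descriptive_formulas lists instead.

```python
def check_equal_lengths(**columns):
    """Return the common length of the given lists, or raise if they differ."""
    lengths = {name: len(values) for name, values in columns.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f'Column lengths differ: {lengths}')
    return next(iter(lengths.values()), 0)


# Toy stand-ins for the lists filled by the query loop (values are made up)
toy_bandgaps = [2.56e-19, None, 2.40e-19]
toy_reduced = ['CH6I3NPb', 'CH6I3NPb', 'Br3CsPb']
toy_descriptive = ['MAPbI3', 'MAPbI3', 'CsPbBr3']

n_entries = check_equal_lengths(
    bandgaps=toy_bandgaps,
    reduced_formulas=toy_reduced,
    descriptive_formulas=toy_descriptive,
)
# Entries without a bandgap are stored as None and become NaN later on
n_missing = sum(value is None for value in toy_bandgaps)
print(f'{n_entries} entries, {n_missing} without a bandgap')
```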
Save Data¶
Put Data into a Pandas DataFrame and Save¶
We can also have a quick look at the DataFrame and convert the units of the bandgap from J to eV.
import os
import pandas as pd
df = pd.DataFrame(
    {
        'reduced_formulas': reduced_formulas,
        'descriptive_formulas': descriptive_formulas,
        'bandgap': bandgaps,
    }
)
df['bandgap'] = pd.to_numeric(df['bandgap'], errors='coerce')
# Convert J to eV: 1 J = 1 / 1.602176634e-19 eV ≈ 6.241509074e18 eV
df['bandgap'] = df['bandgap'] * 6.241509074e18
df.head()
# Create the output folder if needed, then save
os.makedirs('data', exist_ok=True)
df.to_csv('data/perovskite_bandgap_devices.csv', index=False)
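The J-to-eV factor used above is the reciprocal of the elementary charge expressed in coulombs, since 1 eV is defined as 1.602176634e-19 J. A small sketch of the arithmetic follows; the constant and function names are chosen here for illustration only.

```python
# 1 eV = 1.602176634e-19 J (exact, from the SI definition of the elementary charge)
ELEMENTARY_CHARGE = 1.602176634e-19
EV_PER_JOULE = 1.0 / ELEMENTARY_CHARGE


def joule_to_ev(energy_joule):
    """Convert an energy from joules to electronvolts."""
    return energy_joule * EV_PER_JOULE


# A bandgap of 1.6 eV expressed in joules, converted back to eV
bandgap_joule = 1.6 * ELEMENTARY_CHARGE
bandgap_ev = joule_to_ev(bandgap_joule)
print(f'{EV_PER_JOULE:.9e} eV per J, bandgap = {bandgap_ev:.3f} eV')
```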
Load data¶
In the previous part, we retrieved perovskite solar cell data using the NOMAD API, converted it to a Pandas DataFrame, and saved it as perovskite_bandgap_devices.csv in the data folder. If you already have the data, you can start from this section, where we will import and clean it.
The chemical diversity of the dataset can be viewed in the dynamic periodic table of the NOMAD entries explorer. The bottom-right corner of each element shows the number of entries (solar cells) whose absorber contains that element. The dataset is clearly imbalanced: most entries are Pb-based, contain C, N, and H (hybrid perovskites), and are halide compounds, mostly with I and/or Br.
import pandas as pd
df = pd.read_csv('data/perovskite_bandgap_devices.csv')
df.head()
| | reduced_formulas | descriptive_formulas | bandgap |
|---|---|---|---|
| 0 | CH6I3NPb | MAPbI3 | 1.6 |
| 1 | CH6I3NPb | MAPbI3 | NaN |
| 2 | CH6I3NPb | MAPbI3 | 1.5 |
| 3 | CH6I3NPb | MAPbI3 | 1.6 |
| 4 | CH6I3NPb | MAPbI3 | 1.6 |
EDA & Data Cleaning¶
Exploratory Data Analysis (EDA) is a useful step in understanding and preparing datasets for modeling by summarizing data, checking for anomalies, finding patterns and relationships.
Insights from NOMAD GUI¶
The distribution of chemical diversity of the dataset can be explored using the periodic table in the NOMAD solar cell app. The bottom-right corner of each element displays the number of entries (solar cells) that include the element in the absorber. The dataset is notably imbalanced, with the majority of entries being Pb-based, containing C, N, H (hybrid perovskites), and halides like I and Br.
Let's have a look at the statistical summary of the dataset:
df.describe()
We can see that the dataset contains a diverse range of bandgap values (1.16 to 3.05 eV). Device parameters such as open-circuit voltage (voc) and fill factor (ff), if you also retrieve them, can include extreme or zero values that might need attention during data cleaning.
Remove NaNs¶
The dataset might include missing values. Let's check if there are any:
df.isna().sum()
Let's now keep only the rows of the DataFrame where neither the bandgap nor the reduced formula is NaN.
df = df[df['bandgap'].notna()]
df = df[df['reduced_formulas'].notna()]
df.isna().sum()
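The same filtering can also be written with a single dropna call. The sketch below uses a small made-up DataFrame with the same column names so that the effect is easy to inspect; the values are illustrative and not taken from the database.

```python
import numpy as np
import pandas as pd

# Toy data with one missing bandgap and one missing reduced formula
toy = pd.DataFrame(
    {
        'reduced_formulas': ['CH6I3NPb', 'CH6I3NPb', None, 'Br3CsPb'],
        'descriptive_formulas': ['MAPbI3', 'MAPbI3', 'FAPbI3', 'CsPbBr3'],
        'bandgap': [1.6, np.nan, 1.5, 2.3],
    }
)

# Drop every row that lacks a bandgap or a reduced formula
cleaned = toy.dropna(subset=['bandgap', 'reduced_formulas']).reset_index(drop=True)
print(cleaned)
```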
Group Repeated Formulas¶
Even before checking the dataset, we can already guess that there are many repeated formulas and bandgap values in the DataFrame. But let's verify it:
df['reduced_formulas'].value_counts().head(10)
The dataset contains many repeated formulas with corresponding bandgap values. To make each formula unique, we will use the groupby_formula function from CrabNet to group entries by formula and take the mean of their bandgaps. The result is a new DataFrame for our model, df_grouped_formula, with the column names CrabNet expects (formula and target).
Let's rename the columns and apply the groupby_formula function from CrabNet:
from crabnet.utils.data import groupby_formula # type: ignore
# Rename 'bandgap' to 'target' and 'reduced_formulas' to 'formula'
df = df.rename(columns={'bandgap': 'target', 'reduced_formulas': 'formula'})
# Group repeated formulas and take the mean of the target
df_grouped_formula = groupby_formula(df, how='mean')
df_grouped_formula.head()
Let's check the shape of the DataFrame we would like to use for our ML:
df_grouped_formula.shape
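To illustrate what grouping by formula with the mean does, here is a sketch using a plain pandas groupby on made-up values. It is meant only to show the idea and is not CrabNet's implementation of groupby_formula.

```python
import pandas as pd

# Toy data: the same formula reported three times with different bandgaps
toy = pd.DataFrame(
    {
        'formula': ['CH6I3NPb', 'CH6I3NPb', 'Br3CsPb', 'CH6I3NPb'],
        'target': [1.6, 1.5, 2.3, 1.7],
    }
)

# One row per unique formula, with the mean of all reported targets
grouped = toy.groupby('formula', as_index=False)['target'].mean()
print(grouped)
```

Averaging keeps one target per composition, so the same formula cannot end up in both the training and the test split with conflicting values.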
Check Element Prevalence¶
We will use pymatviz to visualize the element prevalence.
If the package is not yet installed in your environment, run the following snippet; otherwise skip it.
You can simply try:
! pip install pymatviz
If that causes problems, install it from the developer repository:
! pip install git+https://github.com/janosh/pymatviz
Once installed, we import and use it:
from pymatviz import ptable_heatmap_plotly # type: ignore
ptable_heatmap_plotly(
    df_grouped_formula['formula'],
    log=True,
    colorscale='BuPu',
    font_colors='black',
    fmt='.3g',
    colorbar=dict(orientation='v', title='Element Prevalence'),
)
As you can see, the current data is heavily based on hybrid halide perovskites, so we expect the model to perform better when predicting these materials. Let's continue and build the model in the next section!
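If you want a quick numerical view of the same imbalance without a plot, the number of formulas containing each element can be counted by hand. The sketch below is an assumption-laden illustration: it uses a simple regular expression that only works for flat reduced formulas such as CH6I3NPb (no parentheses, valid element symbols), and it counts formulas per element, which is not necessarily the counting mode pymatviz uses. The helper name and toy formulas are hypothetical.

```python
import re
from collections import Counter

# An element symbol is one uppercase letter optionally followed by one lowercase letter
ELEMENT_PATTERN = re.compile(r'[A-Z][a-z]?')


def element_occurrence(formulas):
    """Count in how many formulas each element symbol appears."""
    counts = Counter()
    for formula in formulas:
        # A set ensures each element is counted once per formula
        counts.update(set(ELEMENT_PATTERN.findall(formula)))
    return counts


toy_formulas = ['CH6I3NPb', 'Br3CsPb', 'CH5I3N2Pb']
prevalence = element_occurrence(toy_formulas)
print(prevalence.most_common(3))
```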
Build and Fit the Model¶
We randomize the dataset and split it into training, validation, and test sets in a ratio of 80%, 10%, and 10%, respectively.
import numpy as np
# Shuffle with a fixed seed, then cut at 80% and 90% of the rows
train_df, val_df, test_df = np.split(
    df_grouped_formula.sample(frac=1, random_state=42),
    [int(0.8 * len(df_grouped_formula)), int(0.9 * len(df_grouped_formula))],
)
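As a sanity check on the split, the sketch below reproduces the same 80/10/10 cut with positional slicing on a made-up DataFrame and verifies that the three parts are disjoint and cover all rows. The function name split_80_10_10 and the toy data are hypothetical; slicing with iloc is used here as an equivalent alternative to np.split.

```python
import pandas as pd


def split_80_10_10(frame, seed=42):
    """Shuffle the rows and cut them into 80% / 10% / 10% parts."""
    shuffled = frame.sample(frac=1, random_state=seed)
    n = len(shuffled)
    train_end, val_end = int(0.8 * n), int(0.9 * n)
    return (
        shuffled.iloc[:train_end],
        shuffled.iloc[train_end:val_end],
        shuffled.iloc[val_end:],
    )


# Toy data with 20 placeholder formulas
toy = pd.DataFrame(
    {'formula': [f'X{i}' for i in range(20)], 'target': [float(i) for i in range(20)]}
)
toy_train, toy_val, toy_test = split_80_10_10(toy)

# Every original row index should land in exactly one of the three parts
all_indices = list(toy_train.index) + list(toy_val.index) + list(toy_test.index)
print(len(toy_train), len(toy_val), len(toy_test))
```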
We then fit the model using the CrabNet implementation.
from crabnet.crabnet_ import CrabNet # type: ignore
crabnet_bandgap = CrabNet(
    mat_prop='bandgap',
    model_name='perovskite_bg_prediction',
    elem_prop='mat2vec',
    learningcurve=True,
)
crabnet_bandgap.fit(train_df, val_df)
Model evaluation¶
After training, we evaluate the model on the training, validation, and test data, starting with the training data.
from crabnet.utils.figures import act_pred # type: ignore
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
# Train data
# The target column is only a placeholder; predict() ignores its values
train_df_zeros = pd.DataFrame(
    {'formula': train_df['formula'], 'target': [0.0] * len(train_df['formula'])}
)
train_df_predicted, train_df_predicted_sigma = crabnet_bandgap.predict(
    train_df_zeros, return_uncertainty=True
)
act_pred(train_df['target'], train_df_predicted)
r2 = r2_score(train_df['target'], train_df_predicted)
print(f'R2 score: {r2}')
mse = mean_squared_error(train_df['target'], train_df_predicted)
print(f'MSE: {mse}')
mae = mean_absolute_error(train_df['target'], train_df_predicted)
print(f'MAE: {mae} eV')
For the validation data we have:
# Validation data
val_df_zeros = pd.DataFrame(
    {'formula': val_df['formula'], 'target': [0.0] * len(val_df['formula'])}
)
val_df_predicted, val_df_predicted_sigma = crabnet_bandgap.predict(
    val_df_zeros, return_uncertainty=True
)
act_pred(val_df['target'], val_df_predicted)
r2 = r2_score(val_df['target'], val_df_predicted)
print(f'R2 score: {r2}')
mse = mean_squared_error(val_df['target'], val_df_predicted)
print(f'MSE: {mse}')
mae = mean_absolute_error(val_df['target'], val_df_predicted)
print(f'MAE: {mae} eV')
And finally, for the test data:
# Test data
test_df_zeros = pd.DataFrame(
    {'formula': test_df['formula'], 'target': [0.0] * len(test_df['formula'])}
)
test_df_predicted, test_df_predicted_sigma = crabnet_bandgap.predict(
    test_df_zeros, return_uncertainty=True
)
act_pred(test_df['target'], test_df_predicted)
r2 = r2_score(test_df['target'], test_df_predicted)
print(f'R2 score: {r2}')
mse = mean_squared_error(test_df['target'], test_df_predicted)
print(f'MSE: {mse}')
mae = mean_absolute_error(test_df['target'], test_df_predicted)
print(f'MAE: {mae} eV')
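The three metrics reported above for each split can be collected in one helper to avoid repeating the same lines. The sketch below computes them with NumPy using the standard definitions (the same ones the scikit-learn functions follow); the function name regression_metrics and the toy numbers are made up for illustration, and the case of a constant target, where R2 is undefined, is not handled.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Return R2, MSE, and MAE for actual versus predicted values."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    ss_res = float(np.sum(residuals**2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {
        'r2': 1.0 - ss_res / ss_tot,
        'mse': float(np.mean(residuals**2)),
        'mae': float(np.mean(np.abs(residuals))),
    }


# Toy actual and predicted bandgaps in eV
toy_metrics = regression_metrics([1.5, 1.6, 2.3], [1.6, 1.6, 2.2])
print(toy_metrics)
```

Comparing the metrics across the three splits is what reveals overfitting: a training error far below the validation and test errors suggests the model memorized the training formulas.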
Predict Bandgap from Individual Formulas¶
Now we are ready to run some predictions using our trained model. We start by loading the model, in case you want to begin directly here in a new session and the model weights are available.
import numpy as np
import pandas as pd # only if you jump to this cell directly
from crabnet.crabnet_ import CrabNet # type: ignore
from crabnet.kingcrab import SubCrab # type: ignore
# Instantiate SubCrab
sub_crab_model = SubCrab()
# Instantiate CrabNet and set its model to SubCrab
crabnet_model = CrabNet()
crabnet_model.model = sub_crab_model
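# Load the pre-trained network
file_path = r'perovskite_bg_prediction.pth'
crabnet_model.load_network(file_path)
# The prediction cells below call crabnet_bandgap, so point it at the loaded model
crabnet_bandgap = crabnet_model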
Then define a function and run it for predicting the bandgap from individual formulas:
# Function to predict the bandgap of a given formula
def predict_bandgap(formula):
    # The target value is a placeholder required by the input format
    input_df = pd.DataFrame({'formula': [formula], 'target': [0.0]})
    prediction, prediction_sigma = crabnet_bandgap.predict(
        input_df, return_uncertainty=True
    )
    return prediction, prediction_sigma
# Main script to take user input and display predictions
while True:
    formula = input(
        "Enter a formula (e.g., CsPbBr3, CH3NH3PbI3) or type 'exit' to quit: "
    )
    if formula.lower() == 'exit':
        print('Exiting prediction tool. Goodbye!')
        break
    try:
        prediction, prediction_sigma = predict_bandgap(formula)
        print(
            f'Predicted bandgap: {np.round(prediction[0], 3)} +/- {np.round(prediction_sigma[0], 3)} eV'
        )
    except Exception as e:
        print(f'Error during prediction: {e}')
Alternatively, this interactive widget allows you to input a chemical formula, predict its bandgap using the trained model, and check if the formula exists in the dataset. If it does, the widget displays the average bandgap value used during training.
import numpy as np
import pandas as pd
from IPython.display import display
from ipywidgets import Button, HBox, Output, Text, VBox
# Function to predict the bandgap of a given formula
def predict_bandgap(formula):
    input_df = pd.DataFrame({'formula': [formula], 'target': [0.0]})
    prediction, prediction_sigma = crabnet_bandgap.predict(
        input_df, return_uncertainty=True
    )
    return prediction, prediction_sigma
# Function to check if the formula exists in the dataset
def check_formula_in_dataset(formula):
    if formula in df_grouped_formula['formula'].values:
        avg_bandgap = df_grouped_formula.loc[
            df_grouped_formula['formula'] == formula, 'target'
        ].values[0]
        return avg_bandgap
    else:
        return None
# Setting up the widget interface
formula_input = Text(
    value='',
    placeholder='Enter formula (e.g., CsPbBr3)',
    description='Formula:',
)
predict_button = Button(description='Predict Bandgap', button_style='success')
output = Output()
def on_click(b):
    with output:
        output.clear_output()
        try:
            formula = formula_input.value.strip()
            if not formula:
                print('Please enter a valid chemical formula.')
                return
            # Prediction
            prediction, sigma = predict_bandgap(formula)
            print(
                f'Predicted Bandgap: {np.round(prediction[0], 3)} ± {np.round(sigma[0], 3)} eV'
            )
            # Dataset check
            avg_bandgap = check_formula_in_dataset(formula)
            if avg_bandgap is not None:
                print(
                    f'The averaged literature bandgap for {formula} is {avg_bandgap:.3f} eV (from dataset).'
                )
            else:
                print(f"The formula '{formula}' is not contained in the dataset.")
        except Exception as e:
            print(f'Error: {e}')
predict_button.on_click(on_click)
display(VBox([HBox([formula_input, predict_button]), output]))