Making LMDBs from ASE traj files
Contents
Making LMDBs from ASE traj files#
This notebook provides an overview of how to create LMDB datasets to be used with the OCP repo. This tutorial is intended for those who wish to use OCP to train on their own datasets. Those interested in just using OCP data need not worry about these steps as they’ve been automated as part of the download script: https://github.com/Open-Catalyst-Project/ocp/blob/master/scripts/download_data.py.
from ocpmodels.preprocessing import AtomsToGraphs
from ocpmodels.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset
import ase.io
from ase.build import bulk
from ase.build import fcc100, add_adsorbate, molecule
from ase.constraints import FixAtoms
from ase.calculators.emt import EMT
from ase.optimize import BFGS
import matplotlib.pyplot as plt
import lmdb
import pickle
from tqdm import tqdm
import torch
import os
Generate toy dataset: Relaxation of CO on Cu#
# Build a 2x2x3 Cu(100) slab and add a CO molecule at height 3 with a (1, 1) offset.
adslab = fcc100("Cu", size=(2, 2, 3))
ads = molecule("CO")
add_adsorbate(adslab, ads, 3, offset=(1, 1))

# Constrain every atom whose tag equals 3 so it stays fixed during the relaxation.
cons = FixAtoms(indices=[atom.index for atom in adslab if atom.tag == 3])
adslab.set_constraint(cons)
adslab.center(vacuum=13.0, axis=2)
adslab.set_pbc(True)

# Attach the calculator through the ``calc`` attribute;
# ``Atoms.set_calculator`` is deprecated in ASE.
adslab.calc = EMT()

# fmax=0 is not reached here (run() reports False), so the optimizer runs all
# 1000 steps and the trajectory file ends up with 1001 frames.
dyn = BFGS(adslab, trajectory="CuCO_adslab.traj", logfile=None)
dyn.run(fmax=0, steps=1000)
False
# Load every frame of the relaxation trajectory (":" selects all images)
# and display how many were read.
raw_data = ase.io.read("CuCO_adslab.traj", index=":")
len(raw_data)
1001
Initial Structure to Relaxed Energy/Structure (IS2RE/IS2RS) LMDBs#
IS2RE/IS2RS LMDBs utilize the SinglePointLmdbDataset class. This dataset expects the data to be contained in a SINGLE LMDB file. In addition to the attributes defined by AtomsToGraphs, the following attributes must be added for the IS2RE/IS2RS tasks:
pos_relaxed: Relaxed adslab positions
sid: Unique system identifier, arbitrary
y_init: Initial adslab energy, formerly Data.y
y_relaxed: Relaxed adslab energy
tags (optional): 0 - subsurface, 1 - surface, 2 - adsorbate
As a demo, we will use the above generated data to create an IS2R* LMDB file.
Initialize AtomsToGraph feature extractor#
# Feature extractor used below (via ``convert_all``) to turn ASE images into
# graph ``Data`` objects.
# NOTE(review): the meanings of max_neigh / radius are presumed to be the
# per-atom neighbor cap and the neighbor cutoff radius — confirm against the
# AtomsToGraphs documentation.
a2g = AtomsToGraphs(
    max_neigh=50,
    radius=6,
    r_energy=True, # False for test data
    r_forces=True,
    r_distances=False,
    r_fixed=True,
)
Initialize LMDB file#
# Open the LMDB environment that will hold the IS2RE/IS2RS sample.
db = lmdb.open(
    "sample_CuCO.lmdb",
    # 1099511627776 bytes = 2**40 (1 TiB), so the map size is 2 TiB.
    map_size=1099511627776 * 2,
    # Treat the path as a single file rather than a directory.
    subdir=False,
    meminit=False,
    map_async=True,
)
Write data to LMDB#
def read_trajectory_extract_features(a2g, traj_path):
    """Convert the first and last frames of a trajectory into tagged graph objects.

    Args:
        a2g: Converter exposing ``convert_all(images, disable_tqdm=...)``.
        traj_path: Path of the trajectory file, read with ``ase.io.read``.

    Returns:
        The sequence returned by ``a2g.convert_all`` for the initial and final
        frames (in that order); entries 0 and 1 each get a ``tags`` attribute
        holding the first frame's tags as a ``torch.LongTensor``.
    """
    frames = ase.io.read(traj_path, ":")
    first_frame, last_frame = frames[0], frames[-1]
    # Tags are taken from the first frame and applied to both structures.
    frame_tags = first_frame.get_tags()
    graphs = a2g.convert_all([first_frame, last_frame], disable_tqdm=True)
    for position in (0, 1):
        graphs[position].tags = torch.LongTensor(frame_tags)
    return graphs
system_paths = ["CuCO_adslab.traj"]
idx = 0  # running LMDB key, also stored as the system identifier (sid)

for system in system_paths:
    # Extract Data objects for the initial and final frames of the trajectory
    data_objects = read_trajectory_extract_features(a2g, system)
    initial_struc = data_objects[0]
    relaxed_struc = data_objects[1]

    # Move the energy of the initial frame to y_init and attach the relaxed
    # energy / positions to the same object, which is what gets stored.
    initial_struc.y_init = initial_struc.y  # subtract off reference energy, if applicable
    del initial_struc.y
    initial_struc.y_relaxed = relaxed_struc.y  # subtract off reference energy, if applicable
    initial_struc.pos_relaxed = relaxed_struc.pos

    # Filter data if necessary
    # OCP filters adsorption energies > |10| eV

    initial_struc.sid = idx  # arbitrary unique identifier

    # no neighbor edge case check
    if initial_struc.edge_index.shape[1] == 0:
        # Report the offending trajectory; the loop variable is `system`
        # (the previous `traj_path` is not defined in this scope).
        print("no neighbors", system)
        continue

    # Write to LMDB; the transaction commits when the `with` block exits cleanly.
    with db.begin(write=True) as txn:
        txn.put(f"{idx}".encode("ascii"), pickle.dumps(initial_struc, protocol=-1))
    db.sync()
    idx += 1

db.close()
/home/runner/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/preprocessing/atoms_to_graphs.py:139: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1639180507909/work/torch/csrc/utils/tensor_new.cpp:201.)
cell = torch.Tensor(atoms.get_cell()).view(1, 3, 3)
# Open the freshly written LMDB through the dataset class and display its
# number of entries.
# NOTE: constructing this class emits a deprecation warning that recommends
# 'LmdbDataset' instead.
dataset = SinglePointLmdbDataset({"src": "sample_CuCO.lmdb"})
len(dataset)
/home/runner/micromamba-root/envs/buildenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3433: UserWarning: SinglePointLmdbDataset is deprecated and will be removed in the future.Please use 'LmdbDataset' instead.
exec(code_obj, self.user_global_ns, self.user_ns)
1
# Display the first stored entry to check its attributes.
dataset[0]
Data(pos=[14, 3], cell=[1, 3, 3], atomic_numbers=[14], natoms=14, tags=[14], edge_index=[2, 636], cell_offsets=[636, 3], force=[14, 3], fixed=[14], y_init=3.989314410668539, y_relaxed=3.9683558933956853, pos_relaxed=[14, 3], sid=0)
Structure to Energy and Forces (S2EF) LMDBs#
S2EF LMDBs utilize the TrajectoryLmdbDataset class. This dataset expects a directory of LMDB files. In addition to the attributes defined by AtomsToGraphs, the following attributes must be added for the S2EF task:
tags (optional): 0 - subsurface, 1 - surface, 2 - adsorbate
fid: Frame index along the trajectory
sid: Unique system identifier, arbitrary
Additionally, a “length” key must be added to each LMDB file.
As a demo, we will use the above generated data to create an S2EF LMDB dataset
# Create the directory that holds the S2EF LMDB file(s); no error if it exists.
os.makedirs("s2ef", exist_ok=True)
# Open the LMDB environment for the S2EF sample inside that directory.
db = lmdb.open(
    "s2ef/sample_CuCO.lmdb",
    # 1099511627776 bytes = 2**40 (1 TiB), so the map size is 2 TiB.
    map_size=1099511627776 * 2,
    # Treat the path as a single file rather than a directory.
    subdir=False,
    meminit=False,
    map_async=True,
)
# Tags are taken from the first frame and applied to every frame.
tags = raw_data[0].get_tags()
data_objects = a2g.convert_all(raw_data, disable_tqdm=True)

# Count of entries actually written. It is used as the LMDB key so keys stay
# contiguous even when a frame is skipped, and as the stored "length" value.
# NOTE(review): presumably readers index entries 0..length-1 — confirm against
# the dataset class.
num_written = 0

for fid, data in tqdm(enumerate(data_objects), total=len(data_objects)):
    # assign sid
    data.sid = torch.LongTensor([0])

    # assign fid
    data.fid = torch.LongTensor([fid])

    # assign tags, if available
    data.tags = torch.LongTensor(tags)

    # Filter data if necessary
    # OCP filters adsorption energies > |10| eV and forces > |50| eV/A

    # no neighbor edge case check
    if data.edge_index.shape[1] == 0:
        # Report the skipped frame by its index (the previous `traj_path`
        # is not defined in this scope).
        print("no neighbors", fid)
        continue

    # The transaction commits when the `with` block exits cleanly.
    with db.begin(write=True) as txn:
        txn.put(f"{num_written}".encode("ascii"), pickle.dumps(data, protocol=-1))
    num_written += 1

# Record how many entries were written, not how many frames were read.
with db.begin(write=True) as txn:
    txn.put("length".encode("ascii"), pickle.dumps(num_written, protocol=-1))

db.sync()
db.close()
0%| | 0/1001 [00:00<?, ?it/s]
8%|▊ | 85/1001 [00:00<00:01, 848.02it/s]
17%|█▋ | 170/1001 [00:00<00:01, 825.54it/s]
25%|██▌ | 253/1001 [00:00<00:00, 817.57it/s]
33%|███▎ | 335/1001 [00:00<00:00, 805.97it/s]
42%|████▏ | 416/1001 [00:00<00:00, 806.37it/s]
50%|████▉ | 498/1001 [00:00<00:00, 808.16it/s]
58%|█████▊ | 582/1001 [00:00<00:00, 816.64it/s]
66%|██████▋ | 664/1001 [00:00<00:00, 812.64it/s]
75%|███████▍ | 746/1001 [00:00<00:00, 802.73it/s]
83%|████████▎ | 827/1001 [00:01<00:00, 797.37it/s]
91%|█████████ | 910/1001 [00:01<00:00, 806.24it/s]
99%|█████████▉| 991/1001 [00:01<00:00, 804.76it/s]
100%|██████████| 1001/1001 [00:01<00:00, 808.04it/s]
# Open the S2EF LMDB directory through the dataset class and display its
# number of entries.
# NOTE: constructing this class emits a deprecation warning that recommends
# 'LmdbDataset' instead.
dataset = TrajectoryLmdbDataset({"src": "s2ef/"})
len(dataset)
/home/runner/micromamba-root/envs/buildenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3433: UserWarning: TrajectoryLmdbDataset is deprecated and will be removed in the future.Please use 'LmdbDataset' instead.
exec(code_obj, self.user_global_ns, self.user_ns)
1001
# Display the first stored frame to check its attributes.
dataset[0]
Data(pos=[14, 3], cell=[1, 3, 3], atomic_numbers=[14], natoms=14, tags=[14], edge_index=[2, 636], cell_offsets=[636, 3], y=3.989314410668539, force=[14, 3], fixed=[14], sid=[1], fid=[1], id='0_0')
Advanced usage#
TrajectoryLmdbDataset supports multiple LMDB files because of the need to highly parallelize the dataset construction process. With OCP’s largest split containing 135M+ frames, parallelizing the LMDB generation process was necessary. If you find yourself needing to deal with very large datasets, we recommend parallelizing this process.
Interacting with the LMDBs#
Below we demonstrate how to interact with an LMDB to extract particular information.
# Load the S2EF LMDB directory, fetch its first entry, and display it.
dataset = TrajectoryLmdbDataset({"src": "s2ef/"})
data = dataset[0]
data
Data(pos=[14, 3], cell=[1, 3, 3], atomic_numbers=[14], natoms=14, tags=[14], edge_index=[2, 636], cell_offsets=[636, 3], y=3.989314410668539, force=[14, 3], fixed=[14], sid=[1], fid=[1], id='0_0')
# Gather the ``y`` attribute of every entry in the dataset into one tensor
# and display it.
energies = torch.tensor([frame.y for frame in dataset])
energies
tensor([3.9893, 3.9835, 3.9784, ..., 3.9684, 3.9684, 3.9684])
# Plot a 10-bin histogram of the collected energies with a log-scaled y axis.
plt.hist(energies, bins = 10)
plt.yscale("log")
plt.xlabel("Energies")
plt.show()
