Initial Structure to Relaxed Energy (IS2RE) #

The IS2RE task predicts the relaxed energy (the energy of the relaxed state) given the initial state of a system. One approach is to train a regression model that maps the initial structure directly to the relaxed energy. We call this the direct approach to the IS2RE task.

An alternative is to perform a structure relaxation using an S2EF model to obtain the relaxed state and compute the energy of that state (see the IS2RS task below for details about relaxation).
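To make the contrast concrete, here is a minimal sketch of the two approaches. `direct_model`, `s2ef_model`, and `relax_structure` are hypothetical stand-ins for illustration only, not OCP APIs; the rest of this tutorial uses the real EnergyTrainer.

def is2re_direct(direct_model, initial_structure):
    # Direct approach: a single forward pass maps the initial structure
    # straight to the relaxed energy.
    return direct_model(initial_structure)

def is2re_via_relaxation(s2ef_model, relax_structure, initial_structure):
    # Relaxation approach: use the S2EF model's forces to drive a structure
    # relaxation, then evaluate the energy of the final (relaxed) state.
    relaxed_structure = relax_structure(s2ef_model, initial_structure)
    energy, _forces = s2ef_model(relaxed_structure)
    return energy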

Steps for training an IS2RE model#

  1. Define or load a configuration (config), which includes the following

  • task

  • model

  • optimizer

  • dataset

  • trainer

  2. Create an EnergyTrainer object

  3. Train the model

  4. Validate the model

Imports#

from ocpmodels.trainers import EnergyTrainer
from ocpmodels.datasets import SinglePointLmdbDataset
from ocpmodels import models
from ocpmodels.common import logger
from ocpmodels.common.utils import setup_logging
setup_logging()

import numpy as np
import copy
import os

Dataset#

%%bash
mkdir -p data
cd data
wget -q -nc http://dl.fbaipublicfiles.com/opencatalystproject/data/tutorial_data.tar.gz -O tutorial_data.tar.gz
tar -xzvf tutorial_data.tar.gz
./
./is2re/
./is2re/train_100/
./is2re/train_100/data.lmdb
./is2re/train_100/data.lmdb-lock
./is2re/val_20/
./is2re/val_20/data.lmdb
./is2re/val_20/data.lmdb-lock
./s2ef/
./s2ef/train_100/
./s2ef/train_100/data.lmdb
./s2ef/train_100/data.lmdb-lock
./s2ef/val_20/
./s2ef/val_20/data.lmdb
./s2ef/val_20/data.lmdb-lock
train_src = "data/is2re/train_100/data.lmdb"
val_src = "data/is2re/val_20/data.lmdb"
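As a quick sanity check (a sketch, not part of the original tutorial), we can load the training LMDB and look at one sample; `y_relaxed` is the relaxed-energy target used below, while other attribute names follow typical OCP Data objects.

train_dataset = SinglePointLmdbDataset({"src": train_src})
print(len(train_dataset))           # 100 initial structures in this tutorial set
print(train_dataset[0])             # a torch_geometric Data object
print(train_dataset[0].y_relaxed)   # relaxed energy target (eV)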

Normalize data#

If you wish to normalize the targets, you must first compute the mean and standard deviation of the training energies.

train_dataset = SinglePointLmdbDataset({"src": train_src})

energies = []
for data in train_dataset:
  energies.append(data.y_relaxed)

mean = np.mean(energies)
stdev = np.std(energies)
/home/runner/micromamba-root/envs/buildenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py:3433: UserWarning: SinglePointLmdbDataset is deprecated and will be removed in the future.Please use 'LmdbDataset' instead.
  exec(code_obj, self.user_global_ns, self.user_ns)
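Per the deprecation warning above, newer OCP versions expose LmdbDataset instead. Assuming it accepts the same {"src": ...} config in your version, the computation above can be written equivalently as:

from ocpmodels.datasets import LmdbDataset

train_dataset = LmdbDataset({"src": train_src})
energies = [data.y_relaxed for data in train_dataset]
mean, stdev = np.mean(energies), np.std(energies)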

Define the Config#

For this example, we will explicitly define the config; however, a set of default configs can be found here. Default config yaml files can easily be loaded (a sketch of loading one is shown after the note below). Loading a yaml config is preferable when launching jobs from the command line. We have included our best models’ config files here for reference.

Note - we only train for a single epoch with a reduced batch size (GPU memory constraints) for demonstration purposes; modify accordingly for full convergence.
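As a sketch of the YAML route mentioned above: a config file can be read with yaml.safe_load. The path below is hypothetical (substitute a config from the ocp repository), and note that OCP's own config loader additionally resolves `includes` keys, which a plain YAML load does not.

import yaml

with open("configs/is2re/example_config.yml") as f:  # hypothetical path
    loaded_config = yaml.safe_load(f)

task = loaded_config["task"]
model = loaded_config["model"]
optimizer = loaded_config["optim"]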

# Task
task = {
  "dataset": "single_point_lmdb",
  "description": "Relaxed state energy prediction from initial structure.",
  "type": "regression",
  "metric": "mae",
  "labels": ["relaxed energy"],
}
# Model
model = {
    'name': 'gemnet_t',
    "num_spherical": 7,
    "num_radial": 64,
    "num_blocks": 5,
    "emb_size_atom": 256,
    "emb_size_edge": 512,
    "emb_size_trip": 64,
    "emb_size_rbf": 16,
    "emb_size_cbf": 16,
    "emb_size_bil_trip": 64,
    "num_before_skip": 1,
    "num_after_skip": 2,
    "num_concat": 1,
    "num_atom": 3,
    "cutoff": 6.0,
    "max_neighbors": 50,
    "rbf": {"name": "gaussian"},
    "envelope": {
      "name": "polynomial",
      "exponent": 5,
    },
    "cbf": {"name": "spherical_harmonics"},
    "extensive": True,
    "otf_graph": False,
    "output_init": "HeOrthogonal",
    "activation": "silu",
    "scale_file": "configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json",
    "regress_forces": False,
    "direct_forces": False,
}
# Optimizer
optimizer = {
    'batch_size': 1,         # originally 32
    'eval_batch_size': 1,    # originally 32
    'num_workers': 2,
    'lr_initial': 1.e-4,
    'optimizer': 'AdamW',
    'optimizer_params': {"amsgrad": True},
    'scheduler': "ReduceLROnPlateau",
    'mode': "min",
    'factor': 0.8,
    'patience': 3,
    'max_epochs': 1,         # used for demonstration purposes
    'ema_decay': 0.999,
    'clip_grad_norm': 10,
    'loss_energy': 'mae',
}
# Dataset
dataset = [
  {'src': train_src,
   'normalize_labels': True,
   'target_mean': mean,
   'target_std': stdev,
  }, # train set 
  {'src': val_src}, # val set (optional)
]
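For intuition, this is what normalize_labels does conceptually (a sketch, not the trainer's internal code): targets are standardized with the statistics above during training, and model outputs are mapped back to eV at inference.

def normalize(y, mean, std):
    return (y - mean) / std      # applied to training targets

def denormalize(y_norm, mean, std):
    return y_norm * std + mean   # applied to model outputs at inference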

Create the EnergyTrainer#

energy_trainer = EnergyTrainer(
    task=task,
    model=copy.deepcopy(model), # copied for later use, not necessary in practice.
    dataset=dataset,
    optimizer=optimizer,
    identifier="IS2RE-example",
    run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
    is_debug=False, # if True, do not save checkpoint, logs, or results
    print_every=5,
    seed=0, # random seed to use
    logger="tensorboard", # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=True, # use PyTorch Automatic Mixed Precision (faster training and less memory usage)    
)
amp: true
cmd:
  checkpoint_dir: ./checkpoints/2022-10-31-18-01-36-IS2RE-example
  commit: cba9fb6
  identifier: IS2RE-example
  logs_dir: ./logs/tensorboard/2022-10-31-18-01-36-IS2RE-example
  print_every: 5
  results_dir: ./results/2022-10-31-18-01-36-IS2RE-example
  seed: 0
  timestamp_id: 2022-10-31-18-01-36-IS2RE-example
dataset:
  normalize_labels: true
  src: data/is2re/train_100/data.lmdb
  target_mean: !!python/object/apply:numpy.core.multiarray.scalar
  - &id001 !!python/object/apply:numpy.dtype
    args:
    - f8
    - false
    - true
    state: !!python/tuple
    - 3
    - <
    - null
    - null
    - null
    - -1
    - -1
    - 0
  - !!binary |
    MjyJzgpQ978=
  target_std: !!python/object/apply:numpy.core.multiarray.scalar
  - *id001
  - !!binary |
    PnyyzMtk/T8=
gpus: 0
logger: tensorboard
model: gemnet_t
model_attributes:
  activation: silu
  cbf:
    name: spherical_harmonics
  cutoff: 6.0
  direct_forces: false
  emb_size_atom: 256
  emb_size_bil_trip: 64
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_rbf: 16
  emb_size_trip: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  max_neighbors: 50
  num_after_skip: 2
  num_atom: 3
  num_before_skip: 1
  num_blocks: 5
  num_concat: 1
  num_radial: 64
  num_spherical: 7
  otf_graph: false
  output_init: HeOrthogonal
  rbf:
    name: gaussian
  regress_forces: false
  scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json
noddp: false
optim:
  batch_size: 1
  clip_grad_norm: 10
  ema_decay: 0.999
  eval_batch_size: 1
  factor: 0.8
  loss_energy: mae
  lr_initial: 0.0001
  max_epochs: 1
  mode: min
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  patience: 3
  scheduler: ReduceLROnPlateau
slurm: {}
task:
  dataset: single_point_lmdb
  description: Relaxed state energy prediction from initial structure.
  labels:
  - relaxed energy
  metric: mae
  type: regression
trainer: energy
val_dataset:
  src: data/is2re/val_20/data.lmdb

2022-10-31 18:01:06 (INFO): Batch balancing is disabled for single GPU training.
2022-10-31 18:01:06 (INFO): Batch balancing is disabled for single GPU training.
2022-10-31 18:01:06 (INFO): Loading dataset: single_point_lmdb
2022-10-31 18:01:06 (INFO): Loading model: gemnet_t
/home/runner/micromamba-root/envs/buildenv/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py:115: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [6], line 1
----> 1 energy_trainer = EnergyTrainer(
      2     task=task,
      3     model=copy.deepcopy(model), # copied for later use, not necessary in practice.
      4     dataset=dataset,
      5     optimizer=optimizer,
      6     identifier="IS2RE-example",
      7     run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
      8     is_debug=False, # if True, do not save checkpoint, logs, or results
      9     print_every=5,
     10     seed=0, # random seed to use
     11     logger="tensorboard", # logger of choice (tensorboard and wandb supported)
     12     local_rank=0,
     13     amp=True, # use PyTorch Automatic Mixed Precision (faster training and less memory usage)    
     14 )

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/trainers/energy_trainer.py:78, in EnergyTrainer.__init__(self, task, model, dataset, optimizer, identifier, normalizer, timestamp_id, run_dir, is_debug, is_hpo, print_every, seed, logger, local_rank, amp, cpu, slurm, noddp)
     57 def __init__(
     58     self,
     59     task,
   (...)
     76     noddp=False,
     77 ):
---> 78     super().__init__(
     79         task=task,
     80         model=model,
     81         dataset=dataset,
     82         optimizer=optimizer,
     83         identifier=identifier,
     84         normalizer=normalizer,
     85         timestamp_id=timestamp_id,
     86         run_dir=run_dir,
     87         is_debug=is_debug,
     88         is_hpo=is_hpo,
     89         print_every=print_every,
     90         seed=seed,
     91         logger=logger,
     92         local_rank=local_rank,
     93         amp=amp,
     94         cpu=cpu,
     95         name="is2re",
     96         slurm=slurm,
     97         noddp=noddp,
     98     )

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/trainers/base_trainer.py:205, in BaseTrainer.__init__(self, task, model, dataset, optimizer, identifier, normalizer, timestamp_id, run_dir, is_debug, is_hpo, print_every, seed, logger, local_rank, amp, cpu, name, slurm, noddp)
    203 if distutils.is_master():
    204     print(yaml.dump(self.config, default_flow_style=False))
--> 205 self.load()
    207 self.evaluator = Evaluator(task=name)

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/trainers/base_trainer.py:214, in BaseTrainer.load(self)
    212 self.load_datasets()
    213 self.load_task()
--> 214 self.load_model()
    215 self.load_loss()
    216 self.load_optimizer()

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/trainers/base_trainer.py:369, in BaseTrainer.load_model(self)
    364 bond_feat_dim = self.config["model_attributes"].get(
    365     "num_gaussians", 50
    366 )
    368 loader = self.train_loader or self.val_loader or self.test_loader
--> 369 self.model = registry.get_model_class(self.config["model"])(
    370     loader.dataset[0].x.shape[-1]
    371     if loader
    372     and hasattr(loader.dataset[0], "x")
    373     and loader.dataset[0].x is not None
    374     else None,
    375     bond_feat_dim,
    376     self.num_targets,
    377     **self.config["model_attributes"],
    378 ).to(self.device)
    380 if distutils.is_master():
    381     logging.info(
    382         f"Loaded {self.model.__class__.__name__} with "
    383         f"{self.model.num_params} parameters."
    384     )

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/models/gemnet/gemnet.py:261, in GemNetT.__init__(self, num_atoms, bond_feat_dim, num_targets, num_spherical, num_radial, num_blocks, emb_size_atom, emb_size_edge, emb_size_trip, emb_size_rbf, emb_size_cbf, emb_size_bil_trip, num_before_skip, num_after_skip, num_concat, num_atom, regress_forces, direct_forces, cutoff, max_neighbors, rbf, envelope, cbf, extensive, otf_graph, use_pbc, output_init, activation, num_elements, scale_file)
    252 self.int_blocks = torch.nn.ModuleList(int_blocks)
    254 self.shared_parameters = [
    255     (self.mlp_rbf3.linear.weight, self.num_blocks),
    256     (self.mlp_cbf3.weight, self.num_blocks),
    257     (self.mlp_rbf_h.linear.weight, self.num_blocks),
    258     (self.mlp_rbf_out.linear.weight, self.num_blocks + 1),
    259 ]
--> 261 load_scales_compat(self, scale_file)

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/modules/scaling/compat.py:55, in load_scales_compat(module, scale_file)
     52 def load_scales_compat(
     53     module: nn.Module, scale_file: Optional[Union[str, ScaleDict]]
     54 ):
---> 55     scale_dict = _load_scale_dict(scale_file)
     56     if not scale_dict:
     57         return

File ~/work/ml_catalysis_tutorials/ml_catalysis_tutorials/ocp/ocpmodels/modules/scaling/compat.py:31, in _load_scale_dict(scale_file)
     29 path = Path(scale_file)
     30 if not path.exists():
---> 31     raise ValueError(f"Scale file {path} does not exist.")
     33 scale_dict: Optional[ScaleDict] = None
     34 if path.suffix == ".pt":

ValueError: Scale file configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json does not exist.
The ValueError above is raised because scale_file points at a path that does not exist relative to the current working directory; the GemNet-dT scaling-factors JSON ships with the ocp repository, so adjust the path (or run from the repository root) accordingly. Once the trainer is constructed successfully, you can inspect the model:

energy_trainer.model

Train the Model#

Training runs for max_epochs (a single epoch here), logging every print_every steps; the checkpoint with the best validation performance is written to the checkpoint_dir shown above.

energy_trainer.train()

Validate the Model#

Load the best checkpoint#

# The `best_checkpoint.pt` file contains the checkpoint with the best val performance
checkpoint_path = os.path.join(energy_trainer.config["cmd"]["checkpoint_dir"], "best_checkpoint.pt")
checkpoint_path
# Append a test set to the dataset list. We reuse the val set for demonstration.

# Dataset
dataset.append(
  {'src': val_src}, # test set (optional)
)
dataset
pretrained_energy_trainer = EnergyTrainer(
    task=task,
    model=model,
    dataset=dataset,
    optimizer=optimizer,
    identifier="IS2RE-val-example",
    run_dir="./", # directory to save results if is_debug=False. Prediction files are saved here so be careful not to override!
    is_debug=False, # if True, do not save checkpoint, logs, or results
    print_every=10,
    seed=0, # random seed to use
    logger="tensorboard", # logger of choice (tensorboard and wandb supported)
    local_rank=0,
    amp=True, # use PyTorch Automatic Mixed Precision (faster training and less memory usage)
)

pretrained_energy_trainer.load_checkpoint(checkpoint_path=checkpoint_path)

Test the model#

# make predictions on the existing test_loader
predictions = pretrained_energy_trainer.predict(pretrained_energy_trainer.test_loader, results_file="is2re_results", disable_tqdm=False)
energies = predictions["energy"]
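As a final sanity check (a sketch assuming predict returns denormalized energies in dataset order), the predictions can be compared against the stored relaxed energies of the val set we reused as the test set:

test_dataset = SinglePointLmdbDataset({"src": val_src})
targets = np.array([data.y_relaxed for data in test_dataset])
mae = np.mean(np.abs(np.array(energies) - targets))
print(f"Test MAE: {mae:.3f} eV")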