import numpy as np
import tensorflow.keras as ks
import tensorflow as tf
from kgcnn.literature.Schnet import make_model
from kgcnn.utils.data import ragged_tensor_from_nested_numpy
class SchnetEnergy(ks.Model):
    """Subclassed SchNet which outputs energies from coordinates.

    Wraps a SchNet model built by ``kgcnn.literature.Schnet.make_model`` in a
    subclassed Keras model. The model is supposed to be saved and exported.

    Args:
        model_module (str, optional): Identifier of this model type. It is
            only stored and returned by :meth:`get_config`.
            Defaults to "schnet_e".
        schnet_kwargs (dict, optional): Keyword arguments forwarded to
            ``make_model``. ``None`` is treated as an empty dict.
            Defaults to None.
        **kwargs: Passed on to ``ks.Model.__init__``.
    """

    def __init__(self,
                 model_module="schnet_e",
                 schnet_kwargs=None,
                 **kwargs):
        super(SchnetEnergy, self).__init__(**kwargs)
        # ``make_model(**None)`` would raise a TypeError, so the documented
        # default of None is normalized to "no extra arguments".
        if schnet_kwargs is None:
            schnet_kwargs = {}
        self.schnet_kwargs = schnet_kwargs
        self.model_module = model_module
        self._schnet_model = make_model(**schnet_kwargs)
        # Build the model with example data: one forward pass on a dummy
        # single-atom input. The parts are ragged atoms (batch, None),
        # coordinates (batch, None, 3) and indices (batch, None, 2).
        self.predict([
            tf.ragged.constant([[0]]),
            tf.ragged.constant([[[0.0, 0.0, 0.0]]], ragged_rank=1, inner_shape=(3,)),
            tf.ragged.constant([[[0, 0]]], ragged_rank=1, inner_shape=(2,)),
        ])

    def call(self, data, training=False, **kwargs):
        """Call the model output, forward pass.

        Args:
            data (list): Atoms, coordinates, indices.
            training (bool, optional): Training Mode. Defaults to False.
            **kwargs: Ignored.

        Returns:
            y (tf.Tensor): Predicted energy.
        """
        # Forward the mode explicitly so the wrapped SchNet model runs in the
        # same training/inference mode as this wrapper.
        return self._schnet_model(data, training=training)

    def get_config(self):
        """Return the constructor arguments needed to re-create this model.

        Returns:
            dict: ``model_module`` and ``schnet_kwargs``.
        """
        # NOTE(review): the base-class config is not merged in here, presumably
        # because ``ks.Model.get_config`` is not implemented for subclassed
        # models in the targeted TF version. Confirm before changing.
        return {
            "model_module": self.model_module,
            "schnet_kwargs": self.schnet_kwargs,
        }

    def call_to_numpy_output(self, y):
        """Convert a model output to a numpy array.

        Args:
            y (np.ndarray or tf.Tensor): Model output.

        Returns:
            np.ndarray: ``y`` itself if it already is a numpy array,
            otherwise ``y.numpy()``.
        """
        if isinstance(y, np.ndarray):
            return y
        return y.numpy()