import os
from typing import Any, Dict, List, Optional, Tuple

import torch
from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools import file_utils as fu
from biobb_common.tools.file_utils import launchlogger

from biobb_pytorch.mdae.utils.log_utils import get_size
class GeneratePlumed(BiobbObject):
    """
    | biobb_plumed GeneratePlumed
    | Generate PLUMED input for biased dynamics using an MDAE model.
    | Generates a PLUMED input file, features.dat, and converts the model to .ptc format.

    Args:
        input_model_pth_path (str): Path to the trained PyTorch model (.pth) to be converted to TorchScript and used in PLUMED. File type: input. Accepted formats: pth (edam:format_2333).
        input_stats_pt_path (str) (Optional): Path to statistics file (.pt) produced during featurization, used to derive the PLUMED features.dat content. File type: input. Accepted formats: pt (edam:format_2333).
        input_reference_pdb_path (str) (Optional): Path to reference PDB used for FIT_TO_TEMPLATE actions when Cartesian features are present. File type: input. Accepted formats: pdb (edam:format_1476).
        input_ndx_path (str) (Optional): Path to GROMACS index (NDX) file used to define groups when required by PLUMED. File type: input. Accepted formats: ndx (edam:format_2033).
        output_plumed_dat_path (str): Path to the output PLUMED input file. File type: output. Accepted formats: dat (edam:format_2330).
        output_features_dat_path (str): Path to the output features.dat file describing the CVs to PLUMED. File type: output. Accepted formats: dat (edam:format_2330).
        output_model_ptc_path (str): Path to the output TorchScript model file (.ptc) for PLUMED's PYTORCH_MODEL action. File type: output. Accepted formats: ptc (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **include_energy** (*bool*) - (True) Whether to include ENERGY in PLUMED. Note: this key is not read by the current implementation; request ENERGY through **additional_actions** instead.
            * **additional_actions** (*list*) - ([]) List of extra PLUMED actions (dicts with "name" and optional "label" / "params") written right after the features include.
            * **bias** (*list*) - ([]) List of biasing actions (e.g. METAD) to be added to the PLUMED file.
            * **prints** (*dict*) - ({"ARG": "*", "STRIDE": 1, "FILE": "COLVAR"}) PRINT command parameters (e.g. ARG, STRIDE, FILE).
            * **group** (*dict*) - (None) GROUP definition options (label, NDX group or atom selection parameters).
            * **wholemolecules** (*dict*) - (None) WHOLEMOLECULES options when using Cartesian coordinates.
            * **fit_to_template** (*dict*) - (None) FIT_TO_TEMPLATE options (e.g. STRIDE, TYPE, etc.).
            * **pytorch_model** (*dict*) - (None) PYTORCH_MODEL options (label, PACE and other parameters).

    Examples:
        This example shows how to use the GeneratePlumed class to generate a PLUMED input file for biased dynamics using an MDAE model::

            from biobb_plumed.generate_plumed import make_plumed

            prop = {
                "additional_actions": [
                    {
                        "name": "ENERGY",
                        "label": "ene"
                    },
                    {
                        "name": "RMSD",
                        "label": "rmsd",
                        "params": {
                            "TYPE": "OPTIMAL"
                        }
                    }
                ],
                "group": {
                    "label": "c_alphas",
                    "NDX_GROUP": "chA_&_C-alpha"
                },
                "wholemolecules": {
                    "ENTITY0": "c_alphas"
                },
                "fit_to_template": {
                    "STRIDE": 1,
                    "TYPE": "OPTIMAL"
                },
                "pytorch_model": {
                    "label": "cv",
                    "PACE": 1
                },
                "bias": [
                    {
                        "name": "METAD",
                        "label": "bias",
                        "params": {
                            "ARG": "cv.1",
                            "PACE": 500,
                            "HEIGHT": 1.2,
                            "SIGMA": 0.35,
                            "FILE": "HILLS",
                            "BIASFACTOR": 8
                        }
                    }
                ],
                "prints": {
                    "ARG": "cv.*,bias.*",
                    "STRIDE": 1,
                    "FILE": "COLVAR"
                }
            }

            make_plumed(
                input_model_pth_path="model.pth",
                input_stats_pt_path="stats.pt",
                output_plumed_dat_path="plumed.dat",
                output_features_dat_path="features.dat",
                output_model_ptc_path="model.ptc",
                properties=prop
            )

    Info:
        * wrapped_software:
            * name: PLUMED with PyTorch
            * version: >=2.0
            * license: LGPL 3.0
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_stats_pt_path: Optional[str] = None,
        input_reference_pdb_path: Optional[str] = None,
        input_ndx_path: Optional[str] = None,
        output_plumed_dat_path: str = 'plumed.dat',
        output_features_dat_path: str = 'features.dat',
        output_model_ptc_path: str = 'model.ptc',
        properties: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)
        # Snapshot of the constructor arguments, taken before any other local is created.
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {"input_model_pth_path": input_model_pth_path},
            "out": {
                "output_plumed_dat_path": output_plumed_dat_path,
                "output_features_dat_path": output_features_dat_path,
                "output_model_ptc_path": output_model_ptc_path,
            },
        }
        # Optional inputs are only registered when they were actually supplied.
        if input_stats_pt_path:
            self.io_dict["in"]["input_stats_pt_path"] = input_stats_pt_path
        if input_reference_pdb_path:
            self.io_dict["in"]["input_reference_pdb_path"] = input_reference_pdb_path
        if input_ndx_path:
            self.io_dict["in"]["input_ndx_path"] = input_ndx_path

        # Properties
        self.model_pth = input_model_pth_path
        self.stats_pt = input_stats_pt_path
        self.ref_pdb = input_reference_pdb_path
        self.ndx = input_ndx_path
        self.properties = properties
        self.additional_actions = self.properties.get('additional_actions', [])
        self.group = self.properties.get('group', None)
        self.wholemolecules = self.properties.get('wholemolecules', None)
        self.fit_to_template = self.properties.get('fit_to_template', None)
        self.pytorch_model = self.properties.get('pytorch_model', None)
        self.bias = self.properties.get('bias', [])
        self.prints = self.properties.get('prints', {'ARG': '*', 'STRIDE': 1, 'FILE': 'COLVAR'})
        # ARG string for PYTORCH_MODEL; filled in by launch() once the features are known.
        self.arg = ''

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        # The stats file is optional at construction time, so tolerate its absence here;
        # launch() reports the missing file with an explicit ValueError.
        self.stats = self._load_stats()
        shape = (self.stats or {}).get('shape', [None, None])
        self.n_features = shape[1] if len(shape) > 1 else None

    def _load_stats(self) -> Optional[Dict[str, Any]]:
        """Load the statistics file if one was provided.

        Returns:
            The object stored in ``input_stats_pt_path`` (expected to be a dict), or
            ``None`` when no stats path was given.
        """
        if not self.stats_pt:
            return None
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
        # objects, so only trusted stats files must be passed in.
        return torch.load(self.stats_pt, weights_only=False)

    def _generate_features(self) -> Tuple[List[str], List[str]]:
        """Write features.dat and return its lines plus the PYTORCH_MODEL argument labels.

        Returns:
            A ``(feature_lines, arg_labels)`` tuple as produced by
            :meth:`_generate_features_from_stats`.

        Raises:
            ValueError: If no stats file was provided.
        """
        if not self.stats_pt:
            raise ValueError('Input_stats_pt_path is required.')
        # Non-Cartesian or mixed mode
        return self._generate_features_from_stats(
            self.stats, self.io_dict['out']['output_features_dat_path']
        )

    def _generate_features_from_stats(
        self, stats: Dict[str, Any], features_path: str
    ) -> Tuple[List[str], List[str]]:
        """Generate features.dat from the stats for cartesians, distances, angles and dihedrals.

        Args:
            stats (Dict[str, Any]): Loaded stats dictionary. The keys read here are
                ``cartesian_indices``, ``distance_indices``, ``angle_indices`` and
                ``dihedral_indices``; each one is optional.
            features_path (str): Path to write features.dat.

        Returns:
            A tuple ``(feature_lines, arg_labels)``: the PLUMED lines written to
            ``features_path`` and the labels to be joined into the PYTORCH_MODEL ARG.
        """
        feat_lines: List[str] = []
        arg_list: List[str] = []

        def adjust_indices(indices) -> List[int]:
            # Shift by one to get PLUMED's 1-based numbering; int() also normalizes
            # numpy / tensor scalars so they print as plain numbers.
            return [int(idx) + 1 for idx in indices]

        if 'cartesian_indices' in stats:
            pos_atoms = adjust_indices(stats['cartesian_indices'])
            fu.log(f"Found {len(pos_atoms)} Cartesian features.", self.out_log)
            for atom in pos_atoms:
                feat_lines.append(f'p{atom}: POSITION ATOM={atom}')
                arg_list.extend([f'p{atom}.x', f'p{atom}.y', f'p{atom}.z'])

        # (stats key, name used in the log, label prefix, PLUMED action)
        internal_features = (
            ('distance_indices', 'Distance', 'd', 'DISTANCE'),
            ('angle_indices', 'Angle', 'a', 'ANGLE'),
            ('dihedral_indices', 'Dihedral', 't', 'TORSION'),
        )
        for key, kind, prefix, action in internal_features:
            if key not in stats:
                continue
            fu.log(f"Found {len(stats[key])} {kind} features.", self.out_log)
            for count, atom_group in enumerate(stats[key], start=1):
                atoms = ','.join(str(idx) for idx in adjust_indices(atom_group))
                label = f'{prefix}{count}'
                feat_lines.append(f'{label}: {action} ATOMS={atoms}')
                arg_list.append(label)

        with open(features_path, 'w') as f:
            f.writelines(line + '\n' for line in feat_lines)

        return feat_lines, arg_list

    def _convert_model_to_ptc(self) -> None:
        """Convert the PyTorch model to TorchScript format (.ptc).

        ``torch.jit.script`` is tried first; if it fails the model is switched to eval
        mode and exported with ``torch.jit.trace`` using a random ``(1, n_features)`` input.

        Raises:
            ValueError: If scripting fails and the number of input features is unknown,
                so no example input can be built for tracing.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only load
        # model files from a trusted source.
        model = torch.load(self.model_pth, weights_only=False)

        # Convert numpy.int64 attributes to Python int for JIT compatibility
        def convert_attributes_to_int(m) -> None:
            for attr in ('in_features', 'out_features'):
                if hasattr(m, attr):
                    setattr(m, attr, int(getattr(m, attr)))
            for child in m.children():
                convert_attributes_to_int(child)

        convert_attributes_to_int(model)
        self._enable_jit_scripting(model)

        output_path = self.io_dict['out']['output_model_ptc_path']
        try:
            exported_model = torch.jit.script(model)
            how = 'scripted'
        except Exception as script_error:
            # Keep the reason in the log so a failed scripting attempt can be diagnosed.
            fu.log(
                f'jit.script failed ({type(script_error).__name__}: {script_error}): '
                'Attempting jit.trace instead.',
                self.out_log,
            )
            if self.n_features is None:
                raise ValueError(
                    'Cannot trace the model: the number of input features is unknown '
                    '(no "shape" entry found in the stats file).'
                ) from script_error
            # Set to eval mode for tracing (required for BatchNorm with batch size 1)
            model.eval()
            example_input = torch.randn(1, int(self.n_features))  # Batch size 1, flat input
            exported_model = torch.jit.trace(model, example_input)
            how = 'traced'

        torch.jit.save(exported_model, output_path)
        fu.log(f'Successfully {how} and saved model to {output_path}', self.out_log)

    def _enable_jit_scripting(self, module: torch.nn.Module) -> None:
        """Set ``_jit_is_scripting`` to True on the module and on every submodule that defines it."""
        # Module.modules() yields the module itself first, so one loop covers the root too.
        for subm in module.modules():
            if hasattr(subm, '_jit_is_scripting'):
                subm._jit_is_scripting = True

    @staticmethod
    def _format_action(action: Dict[str, Any]) -> str:
        """Render one action dict (``name``, optional ``label`` / ``params``) as a PLUMED line."""
        label = action.get('label', '')
        prefix = f'{label}: ' if label else ''
        tokens = [str(action['name'])]
        tokens.extend(f'{k}={v}' for k, v in (action.get('params') or {}).items())
        return prefix + ' '.join(tokens)

    def _build_plumed_lines(self) -> List[str]:
        """Build the list of lines for the PLUMED file.

        The order is: INCLUDE of features.dat, additional actions, GROUP, WHOLEMOLECULES,
        FIT_TO_TEMPLATE, PYTORCH_MODEL, bias actions and PRINT.

        Returns:
            The PLUMED input lines, without trailing newlines.

        Raises:
            ValueError: If FIT_TO_TEMPLATE is requested for Cartesian features but no
                reference PDB was provided.
        """
        features_dat = os.path.abspath(self.io_dict["out"]["output_features_dat_path"])
        model_ptc = os.path.abspath(self.io_dict['out']['output_model_ptc_path'])
        lines = [f'INCLUDE FILE={features_dat}']

        # Additional actions (e.g., ENERGY, other metrics)
        lines.extend(self._format_action(action) for action in self.additional_actions)

        # GROUP
        if self.group:
            group_label = self.group.get('label', 'C-alpha')
            group_opts = {k: v for k, v in self.group.items() if k not in ('label', 'name')}
            keywords = {str(k).upper() for k in group_opts}
            # An NDX_GROUP selection needs the index file it refers to; supply the
            # provided NDX path unless the caller already set NDX_FILE explicitly.
            if self.ndx and 'NDX_GROUP' in keywords and 'NDX_FILE' not in keywords:
                group_opts['NDX_FILE'] = os.path.abspath(self.ndx)
            params = ' '.join(f'{k}={v}' for k, v in group_opts.items())
            lines.append(f"{group_label}: GROUP {params}")
            fu.log(f'Using GROUP: {group_label}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in group_opts.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)

        uses_positions = bool(self.stats) and 'cartesian_indices' in self.stats

        # WHOLEMOLECULES
        if uses_positions and self.wholemolecules:
            params = ' '.join(f'{k}={v}' for k, v in self.wholemolecules.items())
            lines.append(f'WHOLEMOLECULES {params}')
            fu.log(f'Using WHOLEMOLECULES with parameters: {params}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no WHOLEMOLECULES parameters provided; add WHOLEMOLECULES in properties.', self.out_log)
        elif self.wholemolecules:
            fu.log('NOTE: WHOLEMOLECULES options provided but no POSITION features detected; skipping WHOLEMOLECULES.', self.out_log)

        # FIT_TO_TEMPLATE
        if uses_positions and self.fit_to_template:
            if not self.ref_pdb:
                raise ValueError(
                    'fit_to_template requires input_reference_pdb_path to build the REFERENCE argument.'
                )
            reference = os.path.abspath(self.ref_pdb)
            params = ' '.join(f'{k}={v}' for k, v in self.fit_to_template.items())
            lines.append(f'FIT_TO_TEMPLATE REFERENCE={reference} {params}')
            fu.log('Using FIT_TO_TEMPLATE', self.out_log)
            fu.log(f' Reference PDB: {reference}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in self.fit_to_template.items():
                fu.log(f' > {k.upper()}: {v}', self.out_log)
        elif uses_positions:
            fu.log('WARNING: Using Cartesian coordinates but no FIT_TO_TEMPLATE parameters provided; add FIT_TO_TEMPLATE in properties.', self.out_log)
        elif self.fit_to_template:
            fu.log('NOTE: FIT_TO_TEMPLATE options provided but no POSITION features detected; skipping FIT_TO_TEMPLATE.', self.out_log)

        # PYTORCH_MODEL
        pyt_label = 'cv'
        pyt_params = {'FILE': model_ptc, 'ARG': self.arg}
        if self.pytorch_model:
            pyt_label = self.pytorch_model.get('label', 'cv')
            pyt_params.update({k: v for k, v in self.pytorch_model.items() if k != 'label'})
        params_str = ' '.join(f'{k}={v}' for k, v in pyt_params.items())
        lines.append(f'{pyt_label}: PYTORCH_MODEL {params_str}')
        fu.log(f'Using PYTORCH_MODEL: {pyt_label}', self.out_log)
        fu.log(f' Model ptc file: {model_ptc}', self.out_log)
        # FILE is logged above and ARG can be very long, so only the remaining options
        # are listed; a list keeps the log order deterministic.
        extra_params = [(k, v) for k, v in pyt_params.items() if k not in ('ARG', 'FILE')]
        if extra_params:
            fu.log(' Parameters:', self.out_log)
            for k, v in extra_params:
                fu.log(f' > {k}: {v}', self.out_log)

        # Bias actions
        for command in self.bias:
            lines.append(self._format_action(command))
            fu.log('Using Bias:', self.out_log)
            fu.log(f' Command: {command["name"]}', self.out_log)
            fu.log(' Parameters:', self.out_log)
            for k, v in (command.get('params') or {}).items():
                fu.log(f' > {k}: {v}', self.out_log)

        # PRINT
        prints_str = ' '.join(f'{k}={v}' for k, v in self.prints.items())
        lines.append(f'PRINT {prints_str}')

        return lines

    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`GeneratePlumed <mdae.make_plumed.GeneratePlumed>` object."""
        # Setup Biobb
        if self.check_restart():
            return 0
        self.stage_files()

        # Generate the features first: this is where a missing stats file is reported,
        # so it fails before the model conversion is attempted.
        _, arg_list = self._generate_features()
        self.arg = ','.join(arg_list)
        self._convert_model_to_ptc()
        plumed_lines = self._build_plumed_lines()

        if self.ndx is None and 'cartesian_indices' in self.stats:
            fu.log('WARNING: When employing Cartesian coordinates as collective variables (CVs) for biasing in PLUMED, '
                   'an NDX index file is required to properly define atom groups for fitting and alignment purposes, '
                   'make sure to provide a NDX file.', self.out_log)

        features_path = self.io_dict["out"]["output_features_dat_path"]
        fu.log(f'Generated features.dat at {os.path.abspath(features_path)}', self.out_log)
        fu.log(f'File size: {get_size(features_path)}', self.out_log)

        plumed_path = self.io_dict['out']['output_plumed_dat_path']
        with open(plumed_path, 'w') as f:
            f.write('\n'.join(plumed_lines) + '\n')
        fu.log(f'Generated PLUMED file at {os.path.abspath(plumed_path)}', self.out_log)
        fu.log(f'File size: {get_size(plumed_path)}', self.out_log)

        # Copy files to host
        self.copy_to_host()
        # Remove temporal files
        self.remove_tmp_files()
        self.check_arguments(output_files_created=True, raise_exception=False)
        return 0
def make_plumed(
    input_model_pth_path: str,
    input_stats_pt_path: Optional[str] = None,
    input_reference_pdb_path: Optional[str] = None,
    input_ndx_path: Optional[str] = None,
    output_plumed_dat_path: str = 'plumed.dat',
    output_features_dat_path: str = 'features.dat',
    output_model_ptc_path: str = 'model.ptc',
    properties: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> int:
    """Create the :class:`GeneratePlumed <mdae.make_plumed.GeneratePlumed>` class and
    execute the :meth:`launch() <mdae.make_plumed.GeneratePlumed.launch>` method."""
    # locals() must be captured before any other local name is bound, so that exactly
    # the call arguments (including the ``kwargs`` dict) are forwarded to the class.
    return GeneratePlumed(**dict(locals())).launch()
# Share the class docstring with the functional wrapper so both show the same help text.
make_plumed.__doc__ = GeneratePlumed.__doc__
# Entry point obtained from GeneratePlumed.get_main (inherited helper, not shown in this
# file; presumably it wires make_plumed to a command-line parser).
main = GeneratePlumed.get_main(make_plumed, "Generate PLUMED input for biased dynamics using an MDAE model.")

if __name__ == "__main__":
    main()