Source code for biobb_pytorch.mdae.feat2traj

from biobb_common.generic.biobb_object import BiobbObject
from biobb_common.tools.file_utils import launchlogger
import torch
import numpy as np
import mdtraj as md
import os


class Feat2Traj(BiobbObject):
    """
    | biobb_pytorch Feat2Traj
    | Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.
    | Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.

    Args:
        input_results_npz_path (str): Path to the input reconstructed results file (.npz), typically containing an 'xhat' array. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_results.npz>`_. Accepted formats: npz (edam:format_2333).
        input_stats_pt_path (str): Path to the input model statistics file (.pt) containing cartesian indices and optionally topology. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        input_topology_path (str) (optional): Path to the topology file (PDB) used if no suitable topology is found in the stats file. Used if no topology is found in stats. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/mdae/ref_input_topology.pdb>`_. Accepted formats: pdb (edam:format_1476).
        output_traj_path (str): Path to save the trajectory in xtc/pdb/dcd format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.xtc>`_. Accepted formats: xtc (edam:format_3875), pdb (edam:format_1476), dcd (edam:format_3878).
        output_top_path (str) (optional): Path to save the output topology file (pdb). Used if trajectory format requires separate topology. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/mdae/output_model.pdb>`_. Accepted formats: pdb (edam:format_1476).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **restart** (*bool*) - (False) [WF property] Do not execute if output files exist.

    Examples:
        This example shows how to use the Feat2Traj class to convert a .pt file (features) to a trajectory using cartesian indices and topology from the stats file::

            from biobb_pytorch.mdae.feat2traj import feat2traj

            input_results_npz_path='input_results.npz'
            input_stats_pt_path='input_model.pt'
            input_topology_path='input_topology.pdb'
            output_traj_path='output_model.xtc'
            output_top_path='output_model.pdb'

            prop={}

            feat2traj(input_results_npz_path=input_results_npz_path,
                      input_stats_pt_path=input_stats_pt_path,
                      input_topology_path=input_topology_path,
                      output_traj_path=output_traj_path,
                      output_top_path=output_top_path,
                      properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_results_npz_path: str,
        input_stats_pt_path: str,
        input_topology_path: str = None,
        output_traj_path: str = None,
        output_top_path: str = None,
        properties: dict = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)

        self.input_results_npz_path = input_results_npz_path
        self.input_stats_pt_path = input_stats_pt_path
        self.input_topology_path = input_topology_path
        self.output_traj_path = output_traj_path
        self.output_top_path = output_top_path
        self.properties = properties.copy()
        # Snapshot of the constructor's local namespace; no extra locals are
        # bound before this line so the snapshot content stays unchanged.
        self.locals_var_dict = locals().copy()
        self.io_dict = {
            "in": {
                "input_results_npz_path": input_results_npz_path,
                "input_stats_pt_path": input_stats_pt_path,
                "input_topology_path": input_topology_path,
            },
            "out": {
                "output_traj_path": output_traj_path,
                "output_top_path": output_top_path,
            },
        }

        self.check_properties(properties)
        self.check_arguments()

    @staticmethod
    def _topology_from_stats(topology):
        """Best-effort conversion of the stats-file ``topology`` entry.

        Accepted forms are an MDTraj ``Topology``, an MDTraj ``Trajectory``
        (its topology is used), a path to an existing topology file, or a
        dict holding the PDB text under the ``'pdb_string'`` key.

        Returns:
            The resolved MDTraj topology, or ``None`` when the entry is
            missing, has an unsupported form, or fails to load.
        """
        if topology is None:
            return None
        try:
            if isinstance(topology, md.Topology):
                return topology
            if isinstance(topology, md.Trajectory):
                return topology.topology
            if isinstance(topology, str) and os.path.exists(topology):
                return md.load_topology(topology)
            if isinstance(topology, dict) and 'pdb_string' in topology:
                import tempfile

                # The MDTraj loaders take a file path (the format is picked
                # from the extension), so the PDB text is written to a
                # temporary .pdb file instead of an in-memory buffer.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    pdb_path = os.path.join(tmp_dir, 'topology.pdb')
                    with open(pdb_path, 'w') as handle:
                        handle.write(topology['pdb_string'])
                    return md.load_topology(pdb_path)
        except Exception as e:
            # Deliberately broad: the stats topology is optional, so any
            # failure here only triggers the fallbacks in launch().
            print(f"Warning: Could not load topology from stats file: {e}")
        return None

    @staticmethod
    def _placeholder_topology(n_atoms: int):
        """Build a single-chain, single-residue topology of ``n_atoms`` carbon 'CA' atoms."""
        top = md.Topology()
        chain = top.add_chain()
        res = top.add_residue('RES', chain)
        for _ in range(n_atoms):
            top.add_atom('CA', element=md.element.carbon, residue=res)
        return top

    @launchlogger
    def launch(self) -> int:
        """Execute the :class:`Feat2Traj` class and its `.launch()` method.

        Reads the ``'xhat'`` array from the results file, reshapes it to
        ``(n_frames, n_atoms, 3)`` where ``n_atoms`` is the number of
        cartesian indices stored in the stats file, attaches a topology and
        writes the trajectory.

        The topology is taken, in order, from the stats file, from
        ``input_topology_path``, or from a generated placeholder.

        Returns:
            int: 0 on success.

        Raises:
            ValueError: If the stats file holds no cartesian indices, if the
                features cannot be reshaped to ``(n_frames, n_atoms, 3)``,
                or if the output trajectory extension is not supported.
        """
        # Load features. The archive is closed as soon as the array is read.
        with np.load(self.input_results_npz_path) as results:
            features = results['xhat']

        # Load stats and extract cartesian indices and topology.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use stats files coming from a trusted source.
        stats = torch.load(self.input_stats_pt_path, weights_only=False)
        if not isinstance(stats, dict) or 'cartesian_indices' not in stats:
            raise ValueError('No cartesian indices found in stats file.')
        cartesian_indices = np.array(stats['cartesian_indices'])
        # The topology entry is optional, so a missing key must not fail.
        topology = stats.get('topology')

        n_atoms = len(cartesian_indices)
        n_frames = features.shape[0]
        try:
            # NOTE(review): MDTraj stores coordinates in nanometers; this
            # assumes 'xhat' already uses that unit - confirm.
            coords = features.reshape((n_frames, n_atoms, 3))
        except ValueError as err:
            raise ValueError(
                f"Cannot reshape features of shape {features.shape} into "
                f"({n_frames}, {n_atoms}, 3) coordinates."
            ) from err

        # Try to use topology from stats file if present
        top = self._topology_from_stats(topology)
        if top is not None and top.n_atoms != n_atoms:
            # A topology that does not match the coordinates cannot be used
            # to build the trajectory; fall through to the other sources.
            print(
                "Warning: Could not load topology from stats file: "
                f"it has {top.n_atoms} atoms but {n_atoms} were expected"
            )
            top = None

        # If not found, try input_topology_path
        if top is None and self.input_topology_path is not None and os.path.exists(self.input_topology_path):
            top = md.load_topology(self.input_topology_path)

        # Fallback: create a fake topology
        if top is None:
            top = self._placeholder_topology(n_atoms)

        traj = md.Trajectory(xyz=coords, topology=top)

        if self.output_traj_path:
            ext = os.path.splitext(self.output_traj_path)[1]
            # Compare case-insensitively so e.g. '.XTC' is accepted too.
            ext_key = ext.lower()
            if ext_key == '.xtc':
                traj.save_xtc(self.output_traj_path)
            elif ext_key == '.dcd':
                traj.save_dcd(self.output_traj_path)
            elif ext_key == '.pdb':
                traj.save_pdb(self.output_traj_path)
            else:
                raise ValueError(f'Unknown trajectory extension: {ext}')

            # xtc/dcd files carry no topology; write the first frame as a
            # PDB next to them, but only when an output path was given.
            if ext_key in ('.xtc', '.dcd') and self.output_top_path:
                traj[0].save_pdb(self.output_top_path)

        return 0
def feat2traj(
    input_results_npz_path: str,
    input_stats_pt_path: str,
    input_topology_path: str = None,
    output_traj_path: str = None,
    output_top_path: str = None,
    properties: dict = None,
    **kwargs,
) -> int:
    """Build a :class:`Feat2Traj <Feat2Traj.Feat2Traj>` instance from the
    given arguments and run its :meth:`launch() <Feat2Traj.feat2traj.launch>`
    method, returning the value that ``launch()`` returns."""
    # Captured first, before any other local name is bound, so the snapshot
    # holds exactly the call arguments (the extra keyword arguments are kept
    # as one dict under the 'kwargs' key).
    call_arguments = dict(locals())
    converter = Feat2Traj(**call_arguments)
    return converter.launch()
# The functional wrapper exposes the same documentation as the class.
feat2traj.__doc__ = Feat2Traj.__doc__

# Command-line entry point built by the class from the wrapper function and
# a one-line description of the tool.
main = Feat2Traj.get_main(
    feat2traj,
    "Converts a .pt file (features) to a trajectory using cartesian indices and topology from the stats file.",
)

if __name__ == "__main__":
    main()