import os
import torch
import importlib
from typing import Dict, Any, Type, Optional
from biobb_pytorch.mdae.models import __all__ as AVAILABLE_MODELS
from biobb_pytorch.mdae.utils.model_utils import assert_valid_kwargs
from biobb_pytorch.mdae.utils.log_utils import get_size
from biobb_common.tools.file_utils import launchlogger
from biobb_common.tools import file_utils as fu
from biobb_common.generic.biobb_object import BiobbObject
class BuildModel(BiobbObject):
    """
    | biobb_pytorch BuildModel
    | Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.
    | Builds a PyTorch autoencoder from the given properties.

    Args:
        input_stats_pt_path (str): Path to the input model statistics file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/ref_input_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the model in .pth format. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **model_type** (*str*) - ("AutoEncoder") Name of the model class to instantiate (must exist in biobb_pytorch.mdae.models).
            * **n_cvs** (*int*) - (1) Dimensionality of the latent space.
            * **encoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each encoder layer.
            * **decoder_layers** (*list*) - ([16]) List of integers representing the number of neurons in each decoder layer.
            * **options** (*dict*) - ({"norm_in": {"mode": "min_max"}}) Additional options (e.g. norm_in, optimizer, loss_function, device, etc.).

    Examples:
        This example shows how to use the BuildModel class to build a PyTorch autoencoder model::

            from biobb_pytorch.mdae.build_model import build_model

            input_stats_pt_path = "input_stats.pt"
            output_model_pth_file = "model.pth"

            n_features = 128
            prop = {
                'model_type': 'AutoEncoder',
                'n_cvs': 10,
                'encoder_layers': [n_features, 64, 32],
                'decoder_layers': [32, 64, n_features],
                'options': {
                    'norm_in': {"mode": "min_max"},
                    'optimizer': {
                        'lr': 1e-4
                    }
                }
            }

            build_model(input_stats_pt_path=input_stats_pt_path,
                        output_model_pth_path=None,
                        properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_stats_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        # Call parent class constructor
        super().__init__(properties)

        self.input_stats_pt_path = input_stats_pt_path
        self.output_model_pth_path = output_model_pth_path
        # Shallow copy: nested dicts (e.g. 'options') are still shared with the
        # caller, so helper methods must not mutate them in place.
        self.props = properties.copy()
        self.locals_var_dict = locals().copy()

        # Input/Output files
        self.io_dict = {
            "in": {
                "input_stats_pt_path": input_stats_pt_path,
            },
            "out": {},
        }
        # The output file is optional; register it only when requested.
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path

        # Build the per-feature arguments
        self.options: dict = properties.get("options", {})
        self.model_type: str = properties.get("model_type", "AutoEncoder")
        self.n_cvs: int = properties.get("n_cvs", 1)
        self.encoder_layers: list = properties.get("encoder_layers", [16])
        self.decoder_layers: list = properties.get("decoder_layers", [16])
        self.loss_function: Optional[dict] = properties.get("loss_function", None)
        self.device = self.options.get('device', 'cpu')

        # Load the input statistics.
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load statistics files that come from a trusted source.
        self.stats = torch.load(self.io_dict['in']['input_stats_pt_path'],
                                weights_only=False)

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

        self._validate_props()
        self.model = self._build_model()
        self.loss_fn = self._build_loss()

        # Store hyperparameters on the model for reproducibility
        hparams = {
            'model_type': properties['model_type'],
            'n_cvs': properties['n_cvs'],
            'encoder_layers': properties['encoder_layers'],
            'decoder_layers': properties['decoder_layers'],
            'loss_function': self._hparams_loss_repr(),
            'options': {k: v for k, v in properties['options'].items() if k != 'loss_function'},
        }
        setattr(self.model, '_hparams', hparams)

        # Attach loss_fn and move model to device
        self.model.loss_fn = self.loss_fn
        self.model.to(self.device)

    def _validate_props(self) -> None:
        """Check that every required property is present and the model type is known.

        Note that although ``__init__`` reads these keys with fallback values,
        this check still requires all five of them to be given explicitly.

        Raises:
            KeyError: If any of the required properties is missing.
            ValueError: If ``model_type`` is not listed in ``AVAILABLE_MODELS``.
        """
        required = ['model_type', 'n_cvs', 'encoder_layers', 'decoder_layers', 'options']
        missing = [k for k in required if k not in self.props]
        if missing:
            raise KeyError(f"Missing required properties: {missing}")

        model_type = self.props['model_type']
        if model_type not in AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model_type '{model_type}'. Available: {AVAILABLE_MODELS}"
            )

    def _build_model(self) -> torch.nn.Module:
        """Instantiate the model class named by ``props['model_type']``.

        The class is looked up in ``biobb_pytorch.mdae.models``. The number of
        input features is read from ``self.stats['shape'][1]``.

        Returns:
            The instantiated model.
        """
        module = importlib.import_module('biobb_pytorch.mdae.models')
        ModelClass: Type[torch.nn.Module] = getattr(module, self.props['model_type'])

        options = self.props['options']
        init_args = {
            'n_features': self.stats['shape'][1],
            'n_cvs': self.props['n_cvs'],
            'encoder_layers': self.props['encoder_layers'],
            'decoder_layers': self.props['decoder_layers'],
            # 'loss_function' is handled by _build_loss and 'norm_in' is
            # rebuilt below, so neither is forwarded verbatim.
            'options': {k: v for k, v in options.items() if k not in ['loss_function', 'norm_in']},
        }

        if 'norm_in' in options:
            # Input normalisation gets the loaded statistics plus the requested mode.
            init_args['options']['norm_in'] = {
                'stats': self.stats,
                'mode': options['norm_in'].get('mode'),
            }

        assert_valid_kwargs(ModelClass, init_args, context="model init")
        return ModelClass(**init_args)

    def _build_loss(self) -> Optional[torch.nn.Module]:
        """Build the loss configured in ``options['loss_function']``.

        When no loss is configured, the model's own ``loss_fn`` attribute is
        returned (or ``None`` if the model has none). For ``PhysicsLoss`` the
        loaded statistics are passed as the ``stats`` keyword argument.

        Returns:
            The loss object, or ``None`` when neither the options nor the model
            provide one.

        Raises:
            KeyError: If a loss configuration is given without ``'loss_type'``.
        """
        loss_config = self.props['options'].get('loss_function')
        if not loss_config:
            # Use model's default
            return getattr(self.model, 'loss_fn', None)

        loss_type = loss_config.get('loss_type')
        if not loss_type:
            raise KeyError("'loss_type' must be specified in options['loss_function']")

        # Work on a fresh dict so the caller's properties are never mutated
        # (self.props is only a shallow copy of them).
        kwargs = {k: v for k, v in loss_config.items() if k != 'loss_type'}
        if loss_type == 'PhysicsLoss':
            kwargs['stats'] = self.stats

        loss_module = importlib.import_module('biobb_pytorch.mdae.loss')
        LossClass = getattr(loss_module, loss_type)
        assert_valid_kwargs(LossClass, kwargs, context="loss init")

        try:
            return LossClass(**kwargs)
        except Exception:
            # Best-effort fallback: retry without 'stats'. If 'stats' was never
            # passed the retry would be the identical call, so re-raise instead.
            if 'stats' not in kwargs:
                raise
            kwargs = {k: v for k, v in kwargs.items() if k != 'stats'}
            return LossClass(**kwargs)

    def _hparams_loss_repr(self) -> str:
        """Return a printable description of the loss for the stored hyperparameters."""
        loss_config = self.props['options'].get('loss_function')
        if loss_config:
            name = loss_config.get('loss_type', '')
            # 'stats' is left out so the loaded statistics are not dumped into the text.
            args = [f"{k}={v}" for k, v in loss_config.items() if k not in ['loss_type', 'stats']]
            return f"{name}({', '.join(args)})"
        # fallback to model's representation
        return repr(getattr(self.model, 'loss_fn', ''))

    def save_weights(self, path: str) -> None:
        """Save model.state_dict() to the given path."""
        torch.save(self.model.state_dict(), path)

    @classmethod
    def load_weights(
        cls,
        props: Dict[str, Any],
        path: str,
        input_stats_pt_path: Optional[str] = None,
    ) -> 'BuildModel':
        """Instantiate from props and load state_dict from path.

        Args:
            props: Properties dictionary used to build the model.
            path: Path to a file written by :meth:`save_weights`.
            input_stats_pt_path: Path to the statistics file required by the
                constructor.

        Returns:
            A new instance whose model holds the loaded weights.

        Raises:
            ValueError: If ``input_stats_pt_path`` is not given.
        """
        if input_stats_pt_path is None:
            raise ValueError(
                "load_weights requires 'input_stats_pt_path' to rebuild the model"
            )
        # Pass by keyword: the first positional constructor argument is the
        # statistics path, not the properties.
        inst = cls(input_stats_pt_path=input_stats_pt_path, properties=props)
        state = torch.load(path, map_location=inst.device)
        inst.model.load_state_dict(state)
        inst.model.to(inst.device)
        return inst

    def save_full(self, path: Optional[str] = None) -> None:
        """Serialize the full model object (including architecture).

        Args:
            path: Destination file. Defaults to ``self.output_model_pth_path``.

        Raises:
            ValueError: If no destination is given and no output path was configured.
        """
        target = path or self.output_model_pth_path
        if not target:
            raise ValueError("No output path available to save the model")
        torch.save(self.model, target)

    @staticmethod
    def load_full(path: str) -> torch.nn.Module:
        """Load a model serialized with save_full."""
        # NOTE: weights_only=False unpickles arbitrary objects; trusted files only.
        return torch.load(path, weights_only=False)

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`BuildModel <mdae.build_model.BuildModel>` object
        """

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        if self.output_model_pth_path:
            self.save_full()

        fu.log("## BioBB AutoEncoder Builder ##", self.out_log)
        fu.log("", self.out_log)
        fu.log("Hyperparameters:", self.out_log)
        fu.log("----------------", self.out_log)

        for key, value in getattr(self.model, '_hparams', {}).items():
            if key == 'options':
                # Nested options are printed one per line under their own header.
                fu.log(f"{key}:", self.out_log)
                for sub_key, sub_value in value.items():
                    fu.log(f" {sub_key}: {sub_value}", self.out_log)
            else:
                fu.log(f"{key}: {value}", self.out_log)

        fu.log("", self.out_log)
        fu.log("Model:", self.out_log)
        fu.log("------", self.out_log)
        for line in str(self.model).splitlines():
            fu.log(line, self.out_log)
        fu.log("", self.out_log)

        if self.output_model_pth_path:
            fu.log(f"Model saved in .pth format in "
                   f'{os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )
            fu.log(f'File size: '
                   f'{get_size(self.io_dict["out"]["output_model_pth_path"])}',
                   self.out_log,
                   )

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        # Only verify output files when one was actually requested.
        self.check_arguments(output_files_created=bool(self.output_model_pth_path), raise_exception=False)

        return 0
def build_model(
    properties: dict,
    input_stats_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Create the :class:`BuildModel <mdae.build_model.BuildModel>` class and
    execute the :meth:`launch() <mdae.build_model.BuildModel.launch>` method."""
    # locals() must be read before any other local name is bound, so that only
    # the call arguments are forwarded to the constructor.
    return BuildModel(**dict(locals())).launch()
# Make the functional wrapper carry the same documentation as the class.
build_model.__doc__ = BuildModel.__doc__

# Entry point obtained from the class helper; it is invoked below when this
# module is executed as a script.
main = BuildModel.get_main(build_model, "Build a Molecular Dynamics AutoEncoder (MDAE) PyTorch model.")

if __name__ == "__main__":
    main()
# Example usage:
# n_features = torch.rand(100, 20)
# n_feat = n_features.shape[1]
# properties = {
# 'model_type': 'VariationalAutoEncoder',
# 'n_cvs': 10,
# 'encoder_layers': [n_feat, 64, 32],
# 'decoder_layers': [32, 64, n_feat],
# 'options': {
# 'loss_function': {
# 'loss_type': 'ELBOLoss',
# 'beta': 1.0,
# 'reconstruction': 'mse',
# 'reduction': 'sum'},
# 'optimizer': {
# 'lr': 0.001
# }
# }
# }
# model_builder = BuildModel(input_stats_pt_path="input_stats.pt",
#                            output_model_pth_path="test_model.pth",
#                            properties=properties)
# model_builder.save_full()
# model = model_builder.load_full("test_model.pth")
# print()
# print("Hyperparameters:")
# print("----------------")
# for key, value in model._hparams.items():
# print(f"{key}: {value}")
# print()
# print("Model:")
# print("------")
# print(model)