Source code for biobb_pytorch.mdae.train_model

import torch
import os
from typing import Optional
from biobb_common.tools.file_utils import launchlogger
from biobb_common.tools import file_utils as fu
from biobb_pytorch.mdae.utils.log_utils import get_size
from biobb_common.generic.biobb_object import BiobbObject
import lightning.pytorch.callbacks as _cbs
import lightning.pytorch.loggers as _loggers
import lightning.pytorch.profilers as _profiler
from mlcolvar.utils.trainer import MetricsCallback
import lightning
from mlcolvar.data import DictModule
from mlcolvar.data import DictDataset
import numpy as np


class TrainModel(BiobbObject):
    """
    | biobb_pytorch TrainModel
    | Trains a PyTorch autoencoder using the given properties.
    | Trains a PyTorch autoencoder using the given properties.

    Args:
        input_model_pth_path (str): Path to the input model file. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        input_dataset_pt_path (str): Path to the input dataset file (.pt) produced by the MD feature pipeline. File type: input. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pt>`_. Accepted formats: pt (edam:format_2333).
        output_model_pth_path (str) (Optional): Path to save the trained model (.pth). If omitted, the trained model is only available in memory. File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.pth>`_. Accepted formats: pth (edam:format_2333).
        output_metrics_npz_path (str) (Optional): Path to save training metrics in compressed NumPy format (.npz). File type: output. `Sample file <https://github.com/bioexcel/biobb_pytorch/raw/master/biobb_pytorch/test/reference/mdae/output_model.npz>`_. Accepted formats: npz (edam:format_2333).
        properties (dict - Python dictionary object containing the tool parameters, not input/output files):
            * **Trainer** (*dict*) - ({}) PyTorch Lightning Trainer options (e.g. max_epochs, callbacks, logger, profiler, accelerator, devices, etc.).
            * **Dataset** (*dict*) - ({}) mlcolvar DictDataset / DictModule options (e.g. batch_size, split proportions and shuffling flags).

    Examples:
        This example shows how to use the TrainModel class to train a PyTorch autoencoder model::

            from biobb_pytorch.mdae.train_model import train_model

            input_model_pth_path='input_model.pth'
            input_dataset_pt_path='input_dataset.pt'
            output_model_pth_path='output_model.pth'
            output_metrics_npz_path='output_metrics.npz'

            prop={
                'Trainer': {
                    'max_epochs': 10,
                    'callbacks': {
                        'EarlyStopping': {
                            'monitor': 'valid_loss',
                            'patience': 5
                        }
                    }
                },
                'Dataset': {
                    'batch_size': 32,
                    'split': {
                        'train_prop': 0.8,
                        'val_prop': 0.2
                    }
                }
            }

            train_model(input_model_pth_path=input_model_pth_path,
                        input_dataset_pt_path=input_dataset_pt_path,
                        output_model_pth_path=output_model_pth_path,
                        output_metrics_npz_path=output_metrics_npz_path,
                        properties=prop)

    Info:
        * wrapped_software:
            * name: PyTorch
            * version: >=1.6.0
            * license: BSD 3-Clause
        * ontology:
            * name: EDAM
            * schema: http://edamontology.org/EDAM.owl
    """

    def __init__(
        self,
        input_model_pth_path: str,
        input_dataset_pt_path: str,
        output_model_pth_path: Optional[str] = None,
        output_metrics_npz_path: Optional[str] = None,
        properties: Optional[dict] = None,
        **kwargs,
    ) -> None:
        properties = properties or {}

        super().__init__(properties)

        self.input_model_pth_path = input_model_pth_path
        self.input_dataset_pt_path = input_dataset_pt_path
        self.output_model_pth_path = output_model_pth_path
        self.output_metrics_npz_path = output_metrics_npz_path
        # Work on a shallow copy so the caller's dictionary is not rebound.
        self.properties = properties.copy()
        # Snapshot of the constructor arguments. No extra local variable may
        # be created above this line, otherwise it would leak into the snapshot.
        self.locals_var_dict = locals().copy()

        # Input/Output files; optional outputs are registered only when given.
        self.io_dict = {
            "in": {
                "input_model_pth_path": input_model_pth_path,
                "input_dataset_pt_path": input_dataset_pt_path,
            },
            "out": {},
        }
        if output_model_pth_path:
            self.io_dict["out"]["output_model_pth_path"] = output_model_pth_path
        if output_metrics_npz_path:
            self.io_dict["out"]["output_metrics_npz_path"] = output_metrics_npz_path

        self.Trainer = self.properties.get('Trainer', {})
        self.Dataset = self.properties.get('Dataset', {})

        # Check the properties
        self.check_properties(properties)
        self.check_arguments()

    def get_callbacks(self):
        """Build the list of Lightning callbacks for the trainer.

        A fresh ``MetricsCallback`` is always created, stored in
        ``self.colvars_metrics`` and placed first in the list. Every entry of
        ``properties['Trainer']['callbacks']`` maps a class name looked up in
        ``lightning.pytorch.callbacks`` to the keyword arguments used to
        instantiate it (``None`` or ``{}`` means "no arguments").

        Names that do not exist in ``lightning.pytorch.callbacks`` are skipped,
        as before, but a warning is now logged instead of ignoring them
        silently.

        Returns:
            list: The metrics callback followed by the configured callbacks.
        """
        self.colvars_metrics = MetricsCallback()
        callbacks = [self.colvars_metrics]

        requested = self.properties.get('Trainer', {}).get('callbacks') or {}
        for name, params in requested.items():
            callback_cls = getattr(_cbs, name, None)
            if callback_cls is None:
                fu.log(
                    f"WARNING: No Callback named {name} in lightning.pytorch.callbacks, skipping it",
                    getattr(self, 'out_log', None),
                )
                continue
            callbacks.append(callback_cls(**(params or {})))
        return callbacks

    def get_logger(self):
        """Build the Lightning logger described by ``properties['Trainer']['logger']``.

        Accepted values:
            * missing / ``None`` / empty dict: ``None`` is returned, so the
              Trainer applies its own default.
            * ``True`` / ``False``: returned unchanged so the Trainer sees the
              explicit choice. NOTE(review): ``Trainer(logger=None)`` is
              understood to enable Lightning's default logger, so mapping
              ``False`` to ``None`` would not disable logging - confirm
              against the Lightning Trainer documentation.
            * ``{ClassName: {kwargs}}``: the first entry is instantiated from
              ``lightning.pytorch.loggers``.

        Raises:
            KeyError: If the requested class does not exist in
                ``lightning.pytorch.loggers``.
        """
        logger_prop = self.properties.get('Trainer', {}).get('logger')
        if isinstance(logger_prop, bool):
            return logger_prop
        if not logger_prop:
            return None

        # Only the first (name, params) pair is used.
        logger_type, logger_params = next(iter(logger_prop.items()))
        logger_cls = getattr(_loggers, logger_type, None)
        if logger_cls is None:
            raise KeyError(f"No Logger named {logger_type} in lightning.pytorch.loggers")
        return logger_cls(**(logger_params or {}))

    def get_profiler(self):
        """Build the Lightning profiler described by ``properties['Trainer']['profiler']``.

        Accepted values:
            * missing / ``None`` / empty: ``None`` (no profiler).
            * a string: returned unchanged, because the Trainer itself accepts
              string shortcuts for its built-in profilers.
            * ``{ClassName: {kwargs}}``: the first entry is instantiated from
              ``lightning.pytorch.profilers``.

        Raises:
            KeyError: If the requested class does not exist in
                ``lightning.pytorch.profilers``.
        """
        profiler_prop = self.properties.get('Trainer', {}).get('profiler')
        if not profiler_prop:
            return None
        if isinstance(profiler_prop, str):
            return profiler_prop

        # Only the first (name, params) pair is used.
        profiler_type, profiler_params = next(iter(profiler_prop.items()))
        profiler_cls = getattr(_profiler, profiler_type, None)
        if profiler_cls is None:
            raise KeyError(f"No Profiler named {profiler_type} in lightning.pytorch.profilers")
        return profiler_cls(**(profiler_params or {}))

    def get_trainer(self):
        """Create the ``lightning.Trainer`` from ``properties['Trainer']``.

        The ``callbacks``, ``logger`` and ``profiler`` entries are replaced by
        the objects built in :meth:`get_callbacks`, :meth:`get_logger` and
        :meth:`get_profiler`; every other entry is forwarded verbatim. A
        missing ``Trainer`` section is treated as an empty one (the documented
        default) instead of raising ``KeyError``.
        """
        trainer_cfg = self.properties.get('Trainer') or {}
        special_keys = ('callbacks', 'logger', 'profiler')
        train_params = {k: v for k, v in trainer_cfg.items() if k not in special_keys}
        train_params['callbacks'] = self.get_callbacks()
        train_params['logger'] = self.get_logger()
        train_params['profiler'] = self.get_profiler()
        return lightning.Trainer(**train_params)

    def load_model(self):
        """Load the full model object stored in ``input_model_pth_path``."""
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load model files that come from a trusted source.
        return torch.load(self.io_dict["in"]["input_model_pth_path"], weights_only=False)

    def load_dataset(self):
        """Load ``input_dataset_pt_path`` and wrap its content in a ``DictDataset``."""
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load dataset files that come from a trusted source.
        dataset = torch.load(self.io_dict["in"]["input_dataset_pt_path"], weights_only=False)
        return DictDataset(dataset)

    def create_datamodule(self, dataset):
        """Wrap ``dataset`` in an mlcolvar ``DictModule`` using ``properties['Dataset']``.

        Defaults when an option is absent: ``batch_size=16``,
        ``train_prop=0.8``, ``val_prop=0.2``, no test split, ``shuffle=True``
        and ``random_split=True``. Missing ``Dataset`` or ``split`` sections
        are treated as empty (the documented default) instead of raising
        ``KeyError``.

        Args:
            dataset: The dataset object handed to ``DictModule``.

        Returns:
            DictModule: The configured data module.
        """
        ds_cfg = self.properties.get('Dataset') or {}
        split_cfg = ds_cfg.get('split') or {}

        lengths = [split_cfg.get('train_prop', 0.8), split_cfg.get('val_prop', 0.2)]
        # A third length is only passed when a test split is requested.
        test_prop = split_cfg.get('test_prop', 0)
        if test_prop > 0:
            lengths.append(test_prop)

        return DictModule(
            dataset,
            batch_size=ds_cfg.get('batch_size', 16),
            lengths=lengths,
            shuffle=split_cfg.get('shuffle', True),
            random_split=split_cfg.get('random_split', True),
        )

    def fit_model(self, trainer, model, datamodule):
        """Fit ``model`` on ``datamodule`` by delegating to ``trainer.fit``."""
        trainer.fit(model, datamodule)

    def save_full(self, model) -> None:
        """Serialize the full model object (including architecture)."""
        torch.save(model, self.io_dict["out"]["output_model_pth_path"])

    @launchlogger
    def launch(self) -> int:
        """
        Execute the :class:`TrainModel <mdae.train_model.TrainModel>` object.
        """

        fu.log('## BioBB Model Trainer ##', self.out_log)

        # Setup Biobb
        if self.check_restart():
            return 0

        self.stage_files()

        # Start Pipeline

        # load the model
        fu.log(f'Load model from {os.path.abspath(self.io_dict["in"]["input_model_pth_path"])}', self.out_log)
        self.model = self.load_model()

        # load the dataset
        fu.log(f'Load dataset from {os.path.abspath(self.io_dict["in"]["input_dataset_pt_path"])}', self.out_log)
        self.dataset = self.load_dataset()

        # create the datamodule
        self.datamodule = self.create_datamodule(self.dataset)

        # get the trainer (this also creates self.colvars_metrics)
        self.trainer = self.get_trainer()

        # fit the model
        fu.log('Start training...', self.out_log)
        self.fit_model(self.trainer, self.model, self.datamodule)

        # Set the metrics
        self.metrics = self.colvars_metrics.metrics

        # Save the metrics if path provided
        if self.output_metrics_npz_path:
            np.savez_compressed(self.io_dict["out"]["output_metrics_npz_path"], **self.metrics)
            fu.log(f'Training Metrics saved to {os.path.abspath(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_metrics_npz_path"])}', self.out_log)

        # save the model if path provided
        if self.output_model_pth_path:
            self.save_full(self.model)
            fu.log(f'Trained Model saved to {os.path.abspath(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)
            fu.log(f'File size: {get_size(self.io_dict["out"]["output_model_pth_path"])}', self.out_log)

        # Copy files to host
        self.copy_to_host()

        # Remove temporal files
        self.remove_tmp_files()

        # Output files are only expected when at least one output path was given.
        output_created = bool(self.output_model_pth_path or self.output_metrics_npz_path)
        self.check_arguments(output_files_created=output_created, raise_exception=False)

        return 0
def train_model(
    properties: dict,
    input_model_pth_path: str,
    input_dataset_pt_path: str,
    output_model_pth_path: Optional[str] = None,
    output_metrics_npz_path: Optional[str] = None,
    **kwargs,
) -> int:
    """Build a :class:`TrainModel <mdae.train_model.TrainModel>` from the given
    arguments and run its :meth:`launch() <mdae.train_model.TrainModel.launch>` method,
    returning whatever ``launch()`` returns."""
    # Every argument is forwarded by keyword. The extra keyword arguments are
    # handed over as a single ``kwargs=...`` entry, not unpacked.
    block = TrainModel(
        properties=properties,
        input_model_pth_path=input_model_pth_path,
        input_dataset_pt_path=input_dataset_pt_path,
        output_model_pth_path=output_model_pth_path,
        output_metrics_npz_path=output_metrics_npz_path,
        kwargs=kwargs,
    )
    return block.launch()
# Reuse the class docstring for the functional wrapper so both entry points
# show the same documentation.
train_model.__doc__ = TrainModel.__doc__
# Entry-point callable built by TrainModel.get_main from the wrapper function
# and a one-line description. get_main is not defined in the visible class -
# presumably inherited from BiobbObject; TODO confirm.
main = TrainModel.get_main(train_model, "Trains a PyTorch autoencoder using the given properties.")

# Run main() only when this module is executed as a script.
if __name__ == "__main__":
    main()