---
license: mit
tags:
- molecular-property-prediction
- graph-neural-network
- chemistry
- pytorch
- molecular-dynamics
- force-fields
datasets:
- qm9
- spice
- pfas
metrics:
- mse
- mae
pipeline_tag: graph-ml
library_name: moml
---

# MoML-CA: Molecular Machine Learning for Coarse-grained Applications

This repository contains the **DJMGNN** (Dense Jump Multi-Graph Neural Network) models from the MoML-CA project, designed for molecular property prediction and coarse-grained molecular modeling applications.

## Models Available

### 1. Base Model (`base_model/`)
- **Pre-trained DJMGNN** model trained on multiple molecular datasets
- **Datasets**: QM9, SPICE, PFAS
- **Task**: General molecular property prediction
- **Use case**: Starting point for transfer learning or direct molecular property prediction

### 2. Fine-tuned Model (`finetuned_model/`)
- **PFAS-specialized DJMGNN** model fine-tuned for PFAS molecular properties
- **Base**: Built upon the base model
- **Specialization**: Per- and polyfluoroalkyl substances (PFAS)
- **Use case**: Optimized for PFAS molecular property prediction

## Architecture

**DJMGNN** (Dense Jump Multi-Graph Neural Network) features:
- **Multi-task learning**: Simultaneous node-level and graph-level predictions
- **Jump connections**: Enhanced information flow between layers
- **Dense blocks**: Improved gradient flow and feature reuse
- **Supernode aggregation**: Global graph representation
- **RBF features**: Radial basis function encoding for distance information
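
The RBF feature listed above can be illustrated with a small sketch. It expands a scalar interatomic distance into Gaussian basis features, using 32 centers to match the `rbf_K=32` setting in the usage example. The Gaussian form, the 5.0 Å cutoff, the evenly spaced centers, and the function name `rbf_expand` are assumptions made for illustration; the actual DJMGNN encoding may differ. Plain Python is used to keep the arithmetic visible.

```python
import math


def rbf_expand(distance, num_centers=32, cutoff=5.0, gamma=None):
    """Expand one distance (in Å) into `num_centers` Gaussian RBF features.

    Illustrative only: the cutoff, center grid, and width are assumptions,
    not values taken from the MoML-CA code.
    """
    # Evenly spaced centers on [0, cutoff].
    spacing = cutoff / (num_centers - 1)
    centers = [k * spacing for k in range(num_centers)]
    if gamma is None:
        # Tie the Gaussian width to the center spacing so neighbors overlap.
        gamma = 1.0 / (2.0 * spacing ** 2)
    return [math.exp(-gamma * (distance - c) ** 2) for c in centers]


# Example: a C-F bond of roughly 1.35 Å.
features = rbf_expand(1.35)
```

Each feature is largest when the distance sits near the corresponding center, so the vector is a smooth, fixed-length encoding of the distance.
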

### Architecture Details
- **Hidden Dimensions**: 128
- **Number of Blocks**: 3-4
- **Layers per Block**: 6
- **Input Node Dimensions**: 11-29 (depending on featurization)
- **Node Output Dimensions**: 3 (forces/properties per atom)
- **Graph Output Dimensions**: 19 (molecular descriptors)
- **Energy Output Dimensions**: 1 (total energy)

## Training Details

### Datasets
- **QM9**: ~130k small organic molecules with quantum mechanical properties
- **SPICE**: Molecular dynamics trajectories with forces and energies
- **PFAS**: Per- and polyfluoroalkyl substances dataset with specialized descriptors

### Training Configuration
- **Optimizer**: Adam
- **Learning Rate**: 3e-5 (fine-tuning), 1e-3 (base training)
- **Batch Size**: 4-8 (node tasks), 8-32 (graph tasks)
- **Loss Functions**: MSE for regression, weighted multi-task loss
- **Regularization**: Dropout (0.2), gradient clipping
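
A weighted multi-task loss of the kind listed above is commonly a weighted sum of per-task MSE terms. The sketch below shows that arithmetic in plain Python. The task names mirror the three model outputs (node, graph, energy); the weights and the helper names `mse` and `multitask_loss` are hypothetical and are not taken from the MoML-CA training code.

```python
def mse(pred, target):
    """Mean squared error between two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def multitask_loss(preds, targets, weights):
    """Weighted sum of per-task MSE terms.

    `preds` and `targets` map task name -> list of floats;
    `weights` maps task name -> scalar weight (hypothetical values below).
    """
    return sum(w * mse(preds[task], targets[task]) for task, w in weights.items())


# Toy values chosen so each term is easy to check by hand.
preds = {"node": [1.0, 2.0], "graph": [0.5], "energy": [3.0]}
targets = {"node": [1.0, 4.0], "graph": [0.0], "energy": [2.0]}
weights = {"node": 1.0, "graph": 0.5, "energy": 2.0}

loss = multitask_loss(preds, targets, weights)
# 1.0 * 2.0 + 0.5 * 0.25 + 2.0 * 1.0 = 4.125
```

The weights set how strongly each task pulls on the shared parameters, which matters when the targets (forces, descriptors, energies) have different scales.
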

## Usage

### Loading the Base Model

```python
import torch
from moml.models.mgnn.djmgnn import DJMGNN

# Initialize model architecture
model = DJMGNN(
    in_node_dim=29,  # Adjust based on your featurization
    in_edge_dim=0,
    hidden_dim=128,
    n_blocks=4,
    layers_per_block=6,
    node_output_dims=3,
    graph_output_dims=19,
    energy_output_dims=1,
    jk_mode="attention",
    dropout=0.2,
    use_supernode=True,
    use_rbf=True,
    rbf_K=32,
)

# Load base model checkpoint.
# torch.hub caches downloads by file name, and both checkpoints in this
# repository are named pytorch_model.pt, so give each a distinct cache name.
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/saketh11/MoML-CA/resolve/main/base_model/pytorch_model.pt",
    file_name="moml_ca_base_model.pt",
    map_location="cpu",  # Load on CPU; move the model to a GPU afterwards if needed
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Loading the Fine-tuned Model

```python
# Same architecture setup as above, then:
# A distinct file_name keeps torch.hub from reusing a cached base checkpoint,
# since both files are named pytorch_model.pt.
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/saketh11/MoML-CA/resolve/main/finetuned_model/pytorch_model.pt",
    file_name="moml_ca_finetuned_model.pt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Making Predictions

```python
# Assuming you have a molecular graph 'data' (torch_geometric.data.Data).
# data.batch maps each node to its graph; it is populated when graphs are
# batched with a torch_geometric DataLoader.
with torch.no_grad():
    output = model(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch,
    )

# Extract predictions
node_predictions = output["node_pred"]      # Per-atom properties/forces
graph_predictions = output["graph_pred"]    # Molecular descriptors
energy_predictions = output["energy_pred"]  # Total energy
```

## Performance

### Base Model
- Trained on diverse molecular datasets for robust generalization
- Multi-task learning across node and graph-level properties
- Suitable for transfer learning to specialized domains

### Fine-tuned Model
- Specialized for PFAS molecular properties
- Improved accuracy on fluorinated compounds
- Optimized for environmental and toxicological applications

## Applications

- **Molecular Property Prediction**: HOMO/LUMO, dipole moments, polarizability
- **Force Field Development**: Atomic forces and energies for MD simulations
- **Environmental Chemistry**: PFAS behavior and properties
- **Drug Discovery**: Molecular screening and optimization
- **Materials Science**: Polymer and surface properties

## Links

- **GitHub Repository**: [SAKETH11111/MoML-CA](https://github.com/SAKETH11111/MoML-CA)
- **Documentation**: See repository README and docs/
- **Issues**: Report bugs and request features on GitHub

## License

This project is licensed under the MIT License. See the LICENSE file for details.

## Contributing

Contributions are welcome! Please see the contributing guidelines in the GitHub repository.

---

*For questions or support, please open an issue in the [GitHub repository](https://github.com/SAKETH11111/MoML-CA).*