Graph2Mat

Interface between Graph2Mat and the architectures available in metatrain.

Installation

To install this architecture along with the metatrain package, run:

pip install metatrain[graph2mat]

where the square brackets indicate that you want to install the optional dependencies required for graph2mat.

Default Hyperparameters

The description of all the hyperparameters used in graph2mat is provided further down this page. However, here we provide a YAML file containing all the default hyperparameters, which might be a convenient starting point for creating your own hyperparameter files:

architecture:
  name: experimental.graph2mat
  model:
    basis_yaml: .
    basis_grouping: point_type
    node_hidden_irreps: 20x0e+20x1o+20x2e
    edge_hidden_irreps: 10x0e+10x1o+10x2e
  training:
    optimizer: Adam
    optimizer_kwargs:
      lr: 0.01
    lr_scheduler: ReduceLROnPlateau
    lr_scheduler_kwargs: {}
    distributed: false
    distributed_port: 39591
    batch_size: 16
    num_epochs: 1000
    log_interval: 1
    checkpoint_interval: 100
    per_structure_targets: []
    num_workers: null
    log_mae: true
    log_separate_blocks: false
    best_model_metric: mae_prod
    grad_clip_norm: 1.0
    loss: mse
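
Assuming this block is saved as part of an options file, e.g. options.yaml (a hypothetical name, together with the rest of your options such as the training and validation datasets), training is then typically started with metatrain's command-line interface:

mtt train options.yaml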

Model hyperparameters

The parameters that go under the architecture.model section of the config file are the following:

ModelHypers.basis_yaml: str = '.'

YAML file with the full basis specification for graph2mat.

This file contains a list, with each item being a dictionary to initialize a graph2mat.PointBasis object.
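
For illustration, a basis file for a hypothetical system with two point types might look like this (a minimal sketch: the keys are assumed to map onto the arguments of graph2mat.PointBasis, and the element names, cutoff radii and basis irreps are placeholders):

# hypothetical basis specification; keys assumed to mirror graph2mat.PointBasis arguments
- type: H                    # point type (placeholder element)
  R: 3.0                     # basis cutoff radius (placeholder value)
  basis: 2x0e + 1x1o         # irreps of the basis functions on this point type
- type: C
  R: 4.0
  basis: 3x0e + 2x1o + 1x2e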

ModelHypers.basis_grouping: Literal['point_type', 'basis_shape', 'max'] = 'point_type'

The way in which graph2mat should group the basis sets (to reduce the number of heads).

ModelHypers.node_hidden_irreps: str = '20x0e+20x1o+20x2e'

Irreps to request from the featurizer (per atom).

Graph2Mat will take these features as input.

ModelHypers.edge_hidden_irreps: str = '10x0e+10x1o+10x2e'

Hidden irreps for the edges inside graph2mat.

Trainer hyperparameters

The parameters that go under the architecture.training section of the config file are the following:

TrainerHypers.optimizer: str = 'Adam'

Optimizer for parameter optimization.

The class is looked up by name in torch.optim, so make sure the name matches a valid torch optimizer exactly (including capitalization).

TrainerHypers.optimizer_kwargs: dict = {'lr': 0.01}

Keyword arguments to pass to the optimizer.

These will depend on the optimizer chosen.
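
For example, to switch to AdamW with an explicit weight decay (a sketch; AdamW, lr and weight_decay are standard torch.optim names, but the values here are placeholders):

architecture:
  training:
    optimizer: AdamW
    optimizer_kwargs:
      lr: 0.001            # learning rate passed to torch.optim.AdamW
      weight_decay: 0.01   # placeholder value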

TrainerHypers.lr_scheduler: str | None = 'ReduceLROnPlateau'

Learning rate scheduler to use.

The class is looked up by name in torch.optim.lr_scheduler, so make sure the name matches a valid torch scheduler exactly (including capitalization).

None means that no scheduler will be used.

TrainerHypers.lr_scheduler_kwargs: dict = {}

Keyword arguments to pass to the learning rate scheduler.

These will depend on the scheduler chosen.
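
For example, to tune the default ReduceLROnPlateau scheduler (a sketch; factor and patience are standard arguments of torch.optim.lr_scheduler.ReduceLROnPlateau, the values are placeholders):

architecture:
  training:
    lr_scheduler: ReduceLROnPlateau
    lr_scheduler_kwargs:
      factor: 0.5      # multiply the learning rate by this factor when the metric plateaus
      patience: 20     # evaluations without improvement before reducing (placeholder)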

TrainerHypers.distributed: bool = False

Whether to use distributed training.

TrainerHypers.distributed_port: int = 39591

Port used for DDP communication.

TrainerHypers.batch_size: int = 16

The number of samples to use in each batch of training. This hyperparameter controls the tradeoff between training speed and memory usage. In general, larger batch sizes will lead to faster training, but might require more memory.

TrainerHypers.num_epochs: int = 1000

Number of epochs.

TrainerHypers.log_interval: int = 1

Interval to log metrics.

TrainerHypers.checkpoint_interval: int = 100

Interval to save checkpoints.

TrainerHypers.per_structure_targets: list[str] = []

Targets for which to calculate per-structure losses.

TrainerHypers.num_workers: int | None = None

Number of workers for data loading. If not provided, it is set automatically.

TrainerHypers.log_mae: bool = True

Log MAE alongside RMSE.

TrainerHypers.log_separate_blocks: bool = False

Log per-block error.

TrainerHypers.best_model_metric: Literal['rmse_prod', 'mae_prod', 'loss'] = 'mae_prod'

Metric used to select the best checkpoint (e.g., rmse_prod).

TrainerHypers.grad_clip_norm: float = 1.0

Maximum gradient norm value used for gradient clipping.

TrainerHypers.loss: str | dict[str, LossSpecification] = 'mse'

The loss function to be used. See the Loss functions section for more details.
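
The default shown above uses the simple string form (loss: mse), which applies the same loss to all targets. As a sketch, a per-target dictionary might look as follows (the target name and the keys inside the specification, such as type and weight, are assumptions; consult the Loss functions documentation for the exact schema):

architecture:
  training:
    loss:
      mtt::my_target:    # hypothetical target name
        type: mse        # assumed LossSpecification field
        weight: 1.0      # assumed LossSpecification field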
