route_distances.lstm package

Submodules

route_distances.lstm.data module

Module containing classes for loading and generating data for model training

class route_distances.lstm.data.InMemoryTreeDataset(pairs, trees)

Bases: Dataset

Represent an in-memory set of trees and their pairwise distances

class route_distances.lstm.data.TreeDataModule(pickle_path, batch_size=128, split_part=0.1, split_seed=1984, shuffle=True)

Bases: LightningDataModule

Represent a PyTorch Lightning datamodule for loading and collecting data for model training

Parameters:
  • pickle_path (str)

  • batch_size (int)

  • split_part (float)

  • split_seed (int)

  • shuffle (bool)
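The parameters above are not described further on this page. One plausible reading of split_part and split_seed, sketched below purely as an assumption, is that split_part is the fraction of pairs held out for each of the validation and test sets and that split_seed seeds the random split so it is reproducible. The helper split_indices is hypothetical and is not part of route_distances.

```python
import random
from typing import List, Tuple


def split_indices(
    n_items: int, split_part: float = 0.1, split_seed: int = 1984
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: split item indices into train, validation and test sets.

    Assumes split_part is the fraction reserved for EACH of the validation
    and test sets; the defaults mirror TreeDataModule, the logic is a guess.
    """
    indices = list(range(n_items))
    # A dedicated, seeded generator makes the split reproducible across runs
    random.Random(split_seed).shuffle(indices)
    n_holdout = int(n_items * split_part)
    val = indices[:n_holdout]
    test = indices[n_holdout:2 * n_holdout]
    train = indices[2 * n_holdout:]
    return train, val, test


train, val, test = split_indices(1000)
print(len(train), len(val), len(test))  # 800 100 100
```

Using a private random.Random instance, rather than the global generator, keeps the split independent of any other random calls made during training.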

setup(stage=None)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args:

stage: either 'fit', 'validate', 'test', or 'predict'

Example:

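import pytorch_lightning as pl
from torch import nn


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        # download_data() and tokenize() are placeholders for your own code
        download_data()
        tokenize()

        # don't do this: state set here is not available on other processes
        self.something = ...

    def setup(self, stage):
        data = load_data(...)  # placeholder for your own data loader
        self.l1 = nn.Linear(28, data.num_classes)

Parameters: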

stage (str | None)

Return type:

None

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Return:

A collection of torch.utils.data.DataLoader specifying training samples. For multiple dataloaders, see the PyTorch Lightning documentation on using multiple dataloaders.

The dataloader you return will not be reloaded unless you set the reload_dataloaders_every_n_epochs argument of pytorch_lightning.Trainer to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=self.batch_size, shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar, batch_size=self.batch_size, shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}

Return type:

DataLoader

val_dataloader()

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set the reload_dataloaders_every_n_epochs argument of pytorch_lightning.Trainer to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is requested by:

  • Trainer.fit()

  • Trainer.validate()

and only after these hooks have run:

  • prepare_data()

  • setup()

Note:

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    # loader_a ... loader_n stand for DataLoaders built as above
    return [loader_a, loader_b, ..., loader_n]

Note:

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note:

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

Return type:

DataLoader

test_dataloader()

Implement one or multiple PyTorch DataLoaders for testing.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note:

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return:

A torch.utils.data.DataLoader or a sequence of them specifying testing samples.

Example:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                    download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    # loader_a ... loader_n stand for DataLoaders built as above
    return [loader_a, loader_b, ..., loader_n]

Note:

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Note:

In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.

Return type:

DataLoader

route_distances.lstm.defaults module

Module with defaults used in more than one place

route_distances.lstm.features module

Module for calculating feature and utility vectors for the LSTM-based model

route_distances.lstm.features.add_fingerprints(tree, radius=2, nbits=2048)

Add Morgan fingerprints to the input tree. The tree is modified in place.

Parameters:
  • tree (Dict[str, Any]) – the input tree

  • radius (int) – the radius of the Morgan calculation

  • nbits (int) – the length of the bitvector

Return type:

None

route_distances.lstm.features.remove_reactions(tree)

Remove reaction nodes from the input tree

The original tree is not modified; a new tree is returned.

Parameters:

tree (Dict[str, Any])

Return type:

Dict[str, Any]
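A minimal sketch of this operation, assuming the AiZynthFinder route format in which molecule nodes have "type": "mol", reaction nodes have "type": "reaction", and each node lists its children under "children" (a molecule's children being reactions, and a reaction's children being precursor molecules). The function name remove_reactions_sketch is hypothetical; this is an illustration, not the package's implementation.

```python
import copy
from typing import Any, Dict, List

StrDict = Dict[str, Any]


def remove_reactions_sketch(tree: StrDict) -> StrDict:
    """Return a new tree where each molecule links directly to its precursors."""
    # Copy every key except "children"; values are deep-copied so the
    # caller's tree is never modified
    mol_node = {key: copy.deepcopy(value) for key, value in tree.items() if key != "children"}
    precursors: List[StrDict] = []
    # Assumed layout: molecule -> reaction(s) -> precursor molecules
    for reaction in tree.get("children", []):
        for child in reaction.get("children", []):
            precursors.append(remove_reactions_sketch(child))
    if precursors:
        mol_node["children"] = precursors
    return mol_node


route = {
    "type": "mol",
    "smiles": "CCO",
    "children": [{"type": "reaction", "children": [
        {"type": "mol", "smiles": "CC=O"},
        {"type": "mol", "smiles": "[H][H]"},
    ]}],
}
stripped = remove_reactions_sketch(route)
print([child["smiles"] for child in stripped["children"]])  # ['CC=O', '[H][H]']
```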

route_distances.lstm.features.preprocess_reaction_tree(tree, nfeatures=2048)

Preprocess a reaction tree as produced by AiZynthFinder

Parameters:
  • tree (Dict[str, Any]) – the input tree

  • nfeatures (int) – the number of features, i.e. fingerprint length

Returns:

a tree that can be fed to the LSTM-based model

Return type:

Dict[str, Any]

route_distances.lstm.inference module

Module containing functionality to predict route distance matrices

route_distances.lstm.inference.distances_calculator(model_path)

route_distances.lstm.models module

Module containing the LSTM-based model for calculating route distances

class route_distances.lstm.models.RouteDistanceModel(fp_size=2048, lstm_size=1024, dropout_prob=0.4, learning_rate=0.001, weight_decay=0.001)

Bases: LightningModule

Model for computing the distances between two synthesis routes

Parameters:
  • fp_size (int) – the length of the fingerprint vector

  • lstm_size (int) – the size of the LSTM cell

  • dropout_prob (float) – the dropout probability

  • learning_rate (float) – the initial learning rate of the optimizer

  • weight_decay (float) – weight decay factor of the optimizer

forward(tree_data)

Calculate the pairwise distances between the input trees

Parameters:

tree_data (Dict[str, Any]) – collated trees from the route_distances.lstm.utils.collate_trees function.

Returns:

the distances in condensed form

Return type:

Tensor
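"Condensed form" presumably follows the SciPy convention used by scipy.spatial.distance.pdist and squareform: for n trees only the n(n-1)/2 upper-triangle distances are kept, in the order (0, 1), (0, 2), ..., (1, 2), .... Assuming that convention, the hypothetical pure-Python helper below expands such a vector into the full symmetric matrix; with SciPy installed, scipy.spatial.distance.squareform does the same job.

```python
from typing import List, Sequence


def squareform_sketch(condensed: Sequence[float]) -> List[List[float]]:
    """Expand a condensed distance vector into a full symmetric matrix."""
    # Solve n * (n - 1) / 2 == len(condensed) for the number of items n
    n = int(round((1 + (1 + 8 * len(condensed)) ** 0.5) / 2))
    if n * (n - 1) // 2 != len(condensed):
        raise ValueError("length of a condensed vector must be a triangular number")
    matrix = [[0.0] * n for _ in range(n)]
    pos = 0
    # Walk the upper triangle row by row and mirror each entry
    for i in range(n):
        for j in range(i + 1, n):
            matrix[i][j] = matrix[j][i] = float(condensed[pos])
            pos += 1
    return matrix


print(squareform_sketch([0.5, 0.2, 0.7]))
# [[0.0, 0.5, 0.2], [0.5, 0.0, 0.7], [0.2, 0.7, 0.0]]
```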

training_step(batch, _)

One step in the training loop

Parameters:
  • batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function

  • _ – ignored

Returns:

the loss tensor

Return type:

Tensor

validation_step(batch, _)

One step in the validation loop

Parameters:
  • batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function

  • _ – ignored

Returns:

the validation metrics

Return type:

Dict[str, Any]

validation_epoch_end(outputs)

Log the average validation metrics

Parameters:

outputs (List[Dict[str, Any]])

Return type:

None

test_step(batch, _)

One step in the test loop

Parameters:
  • batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function

  • _ – ignored

Returns:

the test metrics

Return type:

Dict[str, Any]

test_epoch_end(outputs)

Log the average test metrics

Parameters:

outputs (List[Dict[str, Any]])

Return type:

None

configure_optimizers()

Set up the Adam optimizer and scheduler

Return type:

Tuple[List[Adam], List[Dict[str, Any]]]

route_distances.lstm.optim module

Module containing an objective class for Optuna optimization

class route_distances.lstm.optim.OptunaObjective(filename)

Bases: object

Representation of an objective function for Optuna

Parameters:

filename (str) – the path to a pickle file with pre-processed trees

route_distances.lstm.utils module

Module for tree utilities

route_distances.lstm.utils.accumulate_stats(stats)

Accumulate statistics from a list of statistics

Parameters:

stats (List[Dict[str, float]])

Return type:

Dict[str, float]
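The docstring does not say how statistics are accumulated. A minimal sketch, assuming accumulation means summing each named metric across the list, so that a caller such as validation_epoch_end can divide by the number of steps to obtain an average. The function name and the metric names val_loss and val_mae are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List


def accumulate_stats_sketch(stats: List[Dict[str, float]]) -> Dict[str, float]:
    """Sum each named statistic over a list of per-step dictionaries."""
    accum: Dict[str, float] = defaultdict(float)
    for step_stats in stats:
        for key, value in step_stats.items():
            accum[key] += value
    return dict(accum)


# Averaging per-step metrics, as an epoch-end hook might do
outputs = [{"val_loss": 0.5, "val_mae": 0.25}, {"val_loss": 0.25, "val_mae": 0.75}]
totals = accumulate_stats_sketch(outputs)
averages = {key: value / len(outputs) for key, value in totals.items()}
print(averages)  # {'val_loss': 0.375, 'val_mae': 0.5}
```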

route_distances.lstm.utils.add_node_index(node, n=0)

Add an index to the node and all its children

Parameters:
  • node (Dict[str, Any])

  • n (int)

Return type:

int
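A minimal sketch, assuming the index is stored under a hypothetical "index" key, children are listed under "children", nodes are numbered depth-first in pre-order starting at n, and the returned integer is the last index assigned (so the next free index is that value plus one). The function name add_node_index_sketch is hypothetical.

```python
from typing import Any, Dict


def add_node_index_sketch(node: Dict[str, Any], n: int = 0) -> int:
    """Number `node` and its descendants in depth-first pre-order, starting at n."""
    node["index"] = n
    for child in node.get("children", []):
        # Each child subtree starts right after the last index used so far
        n = add_node_index_sketch(child, n + 1)
    return n


tree = {"children": [{"children": [{}]}, {}]}
last = add_node_index_sketch(tree)
print(last, tree["children"][1]["index"])  # 3 3
```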

route_distances.lstm.utils.collate_batch(batch)

Collate a batch of tree data

Collate the first tree of all pairs together, and then collate the second tree of all pairs.

Convert all matrices to PyTorch tensors.

The output dictionary has the following keys:
  • tree1: the collated first tree for all pairs

  • tree2: the collated second tree for all pairs

  • ted: the tree edit distance (TED) for each pair of trees

Parameters:

batch (List[Dict[str, Any]]) – the list of tree data

Returns:

the collated batch

Return type:

Dict[str, Any]
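The pairing logic described above can be sketched as follows. This is a simplified illustration, not the package's code: the output keys tree1, tree2 and ted come from the list above, the per-pair input is assumed to use the same keys, and the tree collation is passed in as a parameter (collate_fn) instead of calling collate_trees. The real function also converts the result to PyTorch tensors.

```python
from typing import Any, Callable, Dict, List

StrDict = Dict[str, Any]


def collate_batch_sketch(
    batch: List[StrDict], collate_fn: Callable[[List[StrDict]], StrDict]
) -> StrDict:
    """Collate all first trees, then all second trees, and gather the TEDs."""
    return {
        "tree1": collate_fn([pair["tree1"] for pair in batch]),
        "tree2": collate_fn([pair["tree2"] for pair in batch]),
        "ted": [pair["ted"] for pair in batch],
    }


# Toy usage: a stand-in collate function that only counts the trees
pairs = [{"tree1": {}, "tree2": {}, "ted": 2.0}, {"tree1": {}, "tree2": {}, "ted": 0.5}]
out = collate_batch_sketch(pairs, lambda trees: {"n_trees": len(trees)})
print(out)  # {'tree1': {'n_trees': 2}, 'tree2': {'n_trees': 2}, 'ted': [2.0, 0.5]}
```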

route_distances.lstm.utils.collate_trees(trees)

Collate a list of trees by stacking the feature vectors, the node orders and the edge orders. The adjacency list is adjusted with an offset.

This is a modified version of the corresponding function in the treelstm package that also converts all matrices to tensors.

The output dictionary has the following keys:
  • features: the stacked node features

  • node_order: the stacked node orders

  • edge_order: the stacked edge orders

  • adjacency_list: the stacked and adjusted adjacency list

  • tree_size: the number of nodes in each tree

Parameters:

trees (List[Dict[str, Any]]) – the trees to collate

Returns:

the collated tree data

Return type:

Dict[str, Any]
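The offset step can be sketched in plain Python. This illustration rests on assumptions: each tree is a dict with features, node_order, edge_order and adjacency_list (one [parent, child] index pair per edge, as in the treelstm package), and lists stand in for tensors. Each tree's adjacency indices are shifted by the number of nodes stacked before it, so they keep pointing at the right rows of the stacked features. The function name is hypothetical.

```python
from typing import Any, Dict, List

StrDict = Dict[str, Any]


def collate_trees_sketch(trees: List[StrDict]) -> StrDict:
    """Stack per-tree data and offset adjacency indices (lists, not tensors)."""
    features: List[Any] = []
    node_order: List[int] = []
    edge_order: List[int] = []
    adjacency_list: List[List[int]] = []
    tree_size: List[int] = []
    offset = 0
    for tree in trees:
        n_nodes = len(tree["features"])
        features.extend(tree["features"])
        node_order.extend(tree["node_order"])
        edge_order.extend(tree["edge_order"])
        # Shift [parent, child] pairs so they index into the stacked node list
        adjacency_list.extend(
            [parent + offset, child + offset] for parent, child in tree["adjacency_list"]
        )
        tree_size.append(n_nodes)
        offset += n_nodes
    return {
        "features": features,
        "node_order": node_order,
        "edge_order": edge_order,
        "adjacency_list": adjacency_list,
        "tree_size": tree_size,
    }


tree_a = {"features": [[1.0], [0.0]], "node_order": [1, 0], "edge_order": [1],
          "adjacency_list": [[0, 1]]}
tree_b = {"features": [[0.5], [0.5], [0.5]], "node_order": [1, 0, 0],
          "edge_order": [1, 1], "adjacency_list": [[0, 1], [0, 2]]}
batch = collate_trees_sketch([tree_a, tree_b])
print(batch["adjacency_list"], batch["tree_size"])  # [[0, 1], [2, 3], [2, 4]] [2, 3]
```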

route_distances.lstm.utils.gather_adjacency_list(node)

Create the adjacency list of a tree

Parameters:

node (Dict[str, Any]) – the current node in the tree

Returns:

the adjacency list

Return type:

List[List[int]]
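A minimal sketch, assuming every node already carries an "index" (for example one assigned by add_node_index), children are listed under "children", and the adjacency list holds one [parent_index, child_index] pair per edge, in the style of the treelstm package. The function name is hypothetical.

```python
from typing import Any, Dict, List


def gather_adjacency_list_sketch(node: Dict[str, Any]) -> List[List[int]]:
    """Collect [parent_index, child_index] pairs by walking the tree depth-first."""
    adjacency: List[List[int]] = []
    for child in node.get("children", []):
        adjacency.append([node["index"], child["index"]])
        # Then add the edges inside the child's own subtree
        adjacency.extend(gather_adjacency_list_sketch(child))
    return adjacency


tree = {
    "index": 0,
    "children": [
        {"index": 1, "children": [{"index": 2}]},
        {"index": 3},
    ],
}
print(gather_adjacency_list_sketch(tree))  # [[0, 1], [1, 2], [0, 3]]
```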

route_distances.lstm.utils.gather_node_attributes(node, key)

Collect node attributes by recursively traversing the tree

Parameters:
  • node (Dict[str, Any]) – the current node in the tree

  • key (str) – the name of the attribute to extract

Returns:

the list of attributes gathered

Return type:

List[Any]
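A minimal sketch of the recursive traversal, assuming children are listed under "children", every node has the requested key, and the attributes are returned in depth-first pre-order (parent before its children). The function name is hypothetical.

```python
from typing import Any, Dict, List


def gather_node_attributes_sketch(node: Dict[str, Any], key: str) -> List[Any]:
    """Collect node[key] for this node and, recursively, all its descendants."""
    values = [node[key]]
    for child in node.get("children", []):
        values.extend(gather_node_attributes_sketch(child, key))
    return values


tree = {
    "smiles": "CCO",
    "children": [
        {"smiles": "CC=O", "children": [{"smiles": "C#C"}]},
        {"smiles": "[H][H]"},
    ],
}
print(gather_node_attributes_sketch(tree, "smiles"))  # ['CCO', 'CC=O', 'C#C', '[H][H]']
```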

Module contents