route_distances.lstm package¶
Submodules¶
route_distances.lstm.data module¶
Module containing classes for loading and generating data for model training
- class route_distances.lstm.data.InMemoryTreeDataset(pairs, trees)¶
 Bases: Dataset
Represent an in-memory set of trees and their pairwise distances
- class route_distances.lstm.data.TreeDataModule(pickle_path, batch_size=128, split_part=0.1, split_seed=1984, shuffle=True)¶
 Bases: LightningDataModule
Represent a PyTorch Lightning datamodule for loading and collecting data for model training
- Parameters:
 pickle_path (str)
batch_size (int)
split_part (float)
split_seed (int)
shuffle (bool)
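A minimal usage sketch follows; the pickle path and the explicit setup() call are illustrative assumptions, not part of the documented API:
from route_distances.lstm.data import TreeDataModule

# "trees.pickle" is a hypothetical path to pre-processed trees and distances
datamodule = TreeDataModule(
    "trees.pickle",
    batch_size=128,
    split_part=0.1,    # fraction of the data held out per split
    split_seed=1984,
    shuffle=True,
)
datamodule.setup()
train_loader = datamodule.train_dataloader()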
- setup(stage=None)¶
 Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Args:
 stage: either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- Parameters:
 stage (str | None)
- Return type:
 None
- train_dataloader()¶
 Implement one or more PyTorch DataLoaders for training.
- Return:
 A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, see the examples below.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data.
This hook is called within fit(), after prepare_data().
- Note:
 Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=True
    )
    return loader

# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist,
        batch_size=self.batch_size,
        shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar,
        batch_size=self.batch_size,
        shuffle=True
    )
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]

# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR(...)
    mnist_loader = torch.utils.data.DataLoader(
        dataset=mnist,
        batch_size=self.batch_size,
        shuffle=True
    )
    cifar_loader = torch.utils.data.DataLoader(
        dataset=cifar,
        batch_size=self.batch_size,
        shuffle=True
    )
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
- Return type:
 DataLoader
- val_dataloader()¶
 Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
This hook is called within fit() and validate(), after prepare_data().
- Note:
 Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
 A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
- Note:
 If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- Note:
 In the case where you return multiple validation dataloaders, validation_step() will have an argument dataloader_idx which matches the order here.
- Return type:
 DataLoader
- test_dataloader()¶
 Implement one or multiple PyTorch DataLoaders for testing.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data.
This hook is called within test(), after prepare_data().
- Note:
 Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
 A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )
    return loader

# can also return multiple dataloaders
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
- Note:
 If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- Note:
 In the case where you return multiple test dataloaders, test_step() will have an argument dataloader_idx which matches the order here.
- Return type:
 DataLoader
route_distances.lstm.defaults module¶
Module with defaults used in more than one place
route_distances.lstm.features module¶
Module for calculating feature and utility vectors for the LSTM-based model
- route_distances.lstm.features.add_fingerprints(tree, radius=2, nbits=2048)¶
 Add Morgan fingerprints to the input tree
- Parameters:
 tree (Dict[str, Any]) – the input tree
radius (int) – the radius of the Morgan calculation
nbits (int) – the length of the bitvector
- Return type:
 None
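As a sketch, assuming an AiZynthFinder-style tree where each molecule node carries a "smiles" field (an assumption based on the module description), the function annotates the tree in place:
from route_distances.lstm.features import add_fingerprints

tree = {"smiles": "c1ccccc1O", "children": []}  # hypothetical minimal tree
add_fingerprints(tree, radius=2, nbits=2048)    # returns None; modifies the tree in place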
- route_distances.lstm.features.remove_reactions(tree)¶
 Remove reaction nodes from the input tree
Does not overwrite the original tree.
- Parameters:
 tree (Dict[str, Any])
- Return type:
 Dict[str, Any]
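A minimal sketch, assuming AiZynthFinder routes alternate molecule and reaction nodes under a "type" key (an assumption; the docstring only states that reaction nodes are removed and the input is not overwritten):
from route_distances.lstm.features import remove_reactions

route = {
    "type": "mol", "smiles": "CCO",
    "children": [
        {"type": "reaction",
         "children": [{"type": "mol", "smiles": "CC=O", "children": []}]}
    ],
}
mol_tree = remove_reactions(route)  # new tree without reaction nodes; `route` is untouched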
- route_distances.lstm.features.preprocess_reaction_tree(tree, nfeatures=2048)¶
 Preprocess a reaction tree as produced by AiZynthFinder
- Parameters:
 tree (Dict[str, Any]) – the input tree
nfeatures (int) – the number of features, i.e. fingerprint length
- Returns:
 a tree that could be fed to the LSTM-based model
- Return type:
 Dict[str, Any]
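A possible end-to-end sketch; the JSON file name is hypothetical and the route is assumed to be one exported from AiZynthFinder:
import json
from route_distances.lstm.features import preprocess_reaction_tree

with open("route.json") as fileobj:  # hypothetical file with one AiZynthFinder route
    route = json.load(fileobj)
model_input = preprocess_reaction_tree(route, nfeatures=2048)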
route_distances.lstm.inference module¶
Module containing routines to make predictions of the route distance matrix
- route_distances.lstm.inference.distances_calculator(model_path)¶
 
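The docstring is empty in the source. Judging from the module description and the signature, a plausible (unverified) usage is:
from route_distances.lstm.inference import distances_calculator

# Assumption: the function loads a trained checkpoint and returns a callable
# that maps a list of routes to a pairwise distance matrix.
calculator = distances_calculator("model.ckpt")  # hypothetical checkpoint path
distance_matrix = calculator(routes)             # `routes` assumed defined elsewhere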
route_distances.lstm.models module¶
Module containing the LSTM-based model for calculating route distances
- class route_distances.lstm.models.RouteDistanceModel(fp_size=2048, lstm_size=1024, dropout_prob=0.4, learning_rate=0.001, weight_decay=0.001)¶
 Bases: LightningModule
Model for computing the distances between two synthesis routes
- Parameters:
 fp_size (int) – the length of the fingerprint vector
lstm_size (int) – the size of the LSTM cell
dropout_prob (float) – the dropout probability
learning_rate (float) – the initial learning rate of the optimizer
weight_decay (float) – weight decay factor of the optimizer
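A training sketch with the documented defaults; the datamodule pairing and the Trainer settings are illustrative assumptions:
import pytorch_lightning as pl
from route_distances.lstm.data import TreeDataModule
from route_distances.lstm.models import RouteDistanceModel

model = RouteDistanceModel(fp_size=2048, lstm_size=1024, dropout_prob=0.4,
                           learning_rate=0.001, weight_decay=0.001)
datamodule = TreeDataModule("trees.pickle")  # hypothetical pre-processed data
trainer = pl.Trainer(max_epochs=10)          # illustrative settings
trainer.fit(model, datamodule=datamodule)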
- forward(tree_data)¶
 Calculate the pairwise distances between the input trees
- Parameters:
tree_data (Dict[str, Any]) – collated trees from the route_distances.lstm.utils.collate_trees function.
- Returns:
 the distances in condensed form
- Return type:
 Tensor
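Since the distances come back in condensed form (read here as the flat vector layout used by scipy.spatial.distance.squareform, an interpretation of the docstring), expanding them to a square matrix could look like this, assuming `trees` holds outputs of preprocess_reaction_tree:
from scipy.spatial.distance import squareform
from route_distances.lstm.utils import collate_trees

tree_data = collate_trees(trees)                  # `trees` assumed defined elsewhere
condensed = model(tree_data)                      # 1D tensor of pairwise distances
matrix = squareform(condensed.detach().numpy())   # square distance matrix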
- training_step(batch, _)¶
 One step in the training loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
 the loss tensor
- Return type:
 Tensor
- validation_step(batch, _)¶
 One step in the validation loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
 the validation metrics
- Return type:
 Dict[str, Any]
- validation_epoch_end(outputs)¶
 Log the average validation metrics
- Parameters:
 outputs (List[Dict[str, Any]])
- Return type:
 None
- test_step(batch, _)¶
 One step in the test loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
 the test metrics
- Return type:
 Dict[str, Any]
- test_epoch_end(outputs)¶
 Log the average test metrics
- Parameters:
 outputs (List[Dict[str, Any]])
- Return type:
 None
- configure_optimizers()¶
 Set up the Adam optimiser and scheduler
- Return type:
 Tuple[List[Adam], List[Dict[str, Any]]]
route_distances.lstm.optim module¶
Module containing an objective class for Optuna optimization
- class route_distances.lstm.optim.OptunaObjective(filename)¶
 Bases: object
Representation of an objective function for Optuna
- Parameters:
 filename (str) – the path to a pickle file with pre-processed trees
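Only the constructor is documented; a standard Optuna loop is sketched below under the assumption that the instance is callable as a study objective:
import optuna
from route_distances.lstm.optim import OptunaObjective

objective = OptunaObjective("trees.pickle")  # hypothetical pickle of pre-processed trees
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)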
route_distances.lstm.utils module¶
Module for tree utilities
- route_distances.lstm.utils.accumulate_stats(stats)¶
 Accumulate a list of statistics dictionaries into a single dictionary
- Parameters:
 stats (List[Dict[str, float]])
- Return type:
 Dict[str, float]
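A small sketch; per-key summation is an assumption suggested by the name, not stated in the docstring:
from route_distances.lstm.utils import accumulate_stats

stats = [{"loss": 0.5, "mae": 1.2}, {"loss": 0.3, "mae": 1.0}]
totals = accumulate_stats(stats)
# assuming per-key summation, totals == {"loss": 0.8, "mae": 2.2}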
- route_distances.lstm.utils.add_node_index(node, n=0)¶
 Add an index to the node and all its children
- Parameters:
 node (Dict[str, Any])
n (int)
- Return type:
 int
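A sketch of indexing a small tree; the exact key written to each node and the traversal order are assumptions:
from route_distances.lstm.utils import add_node_index

tree = {"children": [{"children": []}, {"children": []}]}
next_free = add_node_index(tree, n=0)  # returns an int, presumably the next unused index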
- route_distances.lstm.utils.collate_batch(batch)¶
 Collate a batch of tree data
Collate the first tree of all pairs together, and then collate the second tree of all pairs.
Convert all matrices to PyTorch tensors.
- The output dictionary has the following keys:
 tree1: the collated first tree for all pairs
tree2: the collated second tree for all pairs
ted: the TED for each pair of trees
- Parameters:
 batch (List[Dict[str, Any]]) – the list of tree data
- Returns:
 the collated batch
- Return type:
 Dict[str, Any]
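A sketch of wiring the collate function into a PyTorch DataLoader, assuming `dataset` is an InMemoryTreeDataset of tree pairs and TED values:
from torch.utils.data import DataLoader
from route_distances.lstm.utils import collate_batch

loader = DataLoader(dataset, batch_size=128, collate_fn=collate_batch)
for batch in loader:
    tree1, tree2, ted = batch["tree1"], batch["tree2"], batch["ted"]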
- route_distances.lstm.utils.collate_trees(trees)¶
 Collate a list of trees by stacking the feature vectors, the node orders and the edge orders. The adjacency list is adjusted with an offset.
This is a modified version from the treelstm package that also converts all matrices to tensors.
- The output dictionary has the following keys:
 features: the stacked node features
node_order: the stacked node orders
edge_order: the stacked edge orders
adjacency_list: the stacked and adjusted adjacency list
tree_size: the number of nodes in each tree
- Parameters:
 trees (List[Dict[str, Any]]) – the trees to collate
- Returns:
 the collated tree data
- Return type:
 Dict[str, Any]
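The documented keys can be read straight off the result; here tree_a and tree_b are assumed to be trees produced by preprocess_reaction_tree:
from route_distances.lstm.utils import collate_trees

collated = collate_trees([tree_a, tree_b])
features = collated["features"]   # stacked node feature tensor
sizes = collated["tree_size"]     # number of nodes in each input tree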
- route_distances.lstm.utils.gather_adjacency_list(node)¶
 Create the adjacency list of a tree
- Parameters:
 node (Dict[str, Any]) – the current node in the tree
- Returns:
 the adjacency list
- Return type:
 List[List[int]]
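A sketch; that nodes must first carry indices from add_node_index, and the exact layout of each index pair, are assumptions:
from route_distances.lstm.utils import add_node_index, gather_adjacency_list

tree = {"children": [{"children": []}, {"children": []}]}
add_node_index(tree)
adjacency = gather_adjacency_list(tree)  # list of index pairs, one per edge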
- route_distances.lstm.utils.gather_node_attributes(node, key)¶
 Collect node attributes by recursively traversing the tree
- Parameters:
 node (Dict[str, Any]) – the current node in the tree
key (str) – the name of the attribute to extract
- Returns:
 the list of attributes gathered
- Return type:
 List[Any]
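A small sketch; the exact traversal order of the returned list is not specified in the docstring:
from route_distances.lstm.utils import gather_node_attributes

tree = {"smiles": "CCO", "children": [{"smiles": "CC=O", "children": []}]}
smiles_list = gather_node_attributes(tree, "smiles")  # e.g. ["CCO", "CC=O"]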