route_distances.lstm package
Submodules
route_distances.lstm.data module
Module containing classes for loading and generating data for model training
- class route_distances.lstm.data.InMemoryTreeDataset(pairs, trees)
Bases: Dataset
Represent an in-memory set of trees and their pairwise distances
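As a rough illustration of an in-memory pair dataset, the sketch below keeps every tree once in a list and stores each pair as two indices into that list plus a distance. The class name and the pair layout (`tree_index1`, `tree_index2` and `ted` keys) are hypothetical; the item keys were chosen to match the `tree1`/`tree2`/`ted` keys that `collate_batch` is documented to produce. A real implementation would subclass `torch.utils.data.Dataset`, which only requires `__len__` and `__getitem__`.

```python
from typing import Any, Dict, List


class InMemoryTreeDatasetSketch:
    """Hypothetical stand-in for InMemoryTreeDataset (no torch dependency).

    Assumption: each pair is a dict with integer keys "tree_index1" and
    "tree_index2" pointing into `trees`, and a float "ted" distance.
    """

    def __init__(self, pairs: List[Dict[str, Any]], trees: List[Dict[str, Any]]) -> None:
        self.pairs = pairs
        self.trees = trees

    def __len__(self) -> int:
        # One dataset item per pair of trees
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        pair = self.pairs[idx]
        # Look up both trees by index so each tree is stored only once
        return {
            "tree1": self.trees[pair["tree_index1"]],
            "tree2": self.trees[pair["tree_index2"]],
            "ted": pair["ted"],
        }
```

Storing indices rather than copies keeps memory proportional to the number of distinct trees, even when each tree appears in many pairs.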
- class route_distances.lstm.data.TreeDataModule(pickle_path, batch_size=128, split_part=0.1, split_seed=1984, shuffle=True)
Bases: LightningDataModule
Represent a PyTorch Lightning datamodule for loading and collecting data for model training
- Parameters:
pickle_path (str)
batch_size (int)
split_part (float)
split_seed (int)
shuffle (bool)
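The `split_part` and `split_seed` parameters suggest a reproducible random split of the data. The sketch below is one possible reading, assuming that `split_part` is the fraction held out for validation and that a second fraction of the same size is held out for testing (the datamodule also exposes a test dataloader). The function name is hypothetical and the actual partitioning in `TreeDataModule.setup` may differ.

```python
import random
from typing import List, Tuple


def split_indices(
    n_items: int, split_part: float = 0.1, split_seed: int = 1984
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical train/validation/test split of item indices.

    Assumption: validation and test each receive `split_part` of the items;
    the remainder is used for training.
    """
    indices = list(range(n_items))
    # A dedicated Random instance keeps the split reproducible without
    # touching the global random state
    random.Random(split_seed).shuffle(indices)
    n_holdout = int(n_items * split_part)
    val = indices[:n_holdout]
    test = indices[n_holdout:2 * n_holdout]
    train = indices[2 * n_holdout:]
    return train, val, test
```

Calling `split_indices(100)` twice returns identical partitions, which is the point of exposing a seed.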
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Args:
stage: either 'fit', 'validate', 'test', or 'predict'
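Example:
```python
import torch.nn as nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()  # placeholder helper
        tokenize()  # placeholder helper

        # don't do this: prepare_data runs on a single process,
        # so state assigned here is not available on the others
        self.something = "value"

    def setup(self, stage):
        data = load_data(...)  # placeholder helper
        self.l1 = nn.Linear(28, data.num_classes)
```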
- Parameters:
stage (str | None)
- Return type:
None
- train_dataloader()
Implement one or more PyTorch DataLoaders for training.
- Return:
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, see the PyTorch Lightning documentation on multiple training dataloaders.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
For data processing, use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
- fit()
- prepare_data()
- Note:
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10


# Each function below is meant to be a method of a LightningModule subclass

# single dataloader
def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
    return loader


# multiple dataloaders, return as list
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True)
    cifar_loader = torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)
    # each batch will be a list of tensors: [batch_mnist, batch_cifar]
    return [mnist_loader, cifar_loader]


# multiple dataloaders, return as dict
def train_dataloader(self):
    mnist = MNIST(...)
    cifar = CIFAR10(...)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=self.batch_size, shuffle=True)
    cifar_loader = torch.utils.data.DataLoader(dataset=cifar, batch_size=self.batch_size, shuffle=True)
    # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
    return {'mnist': mnist_loader, 'cifar': cifar_loader}
```
- Return type:
DataLoader
- val_dataloader()
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
- fit()
- validate()
- prepare_data()
- Note:
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


# Each function below is meant to be a method of a LightningModule subclass
def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
    return loader


# can also return multiple dataloaders (loader_a ... loader_n are placeholders)
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
- Note:
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- Note:
In the case where you return multiple validation dataloaders, validation_step() will have an argument dataloader_idx that matches the order here.
- Return type:
DataLoader
- test_dataloader()
Implement one or multiple PyTorch DataLoaders for testing.
For data processing, use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
- test()
- prepare_data()
- Note:
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return:
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


# Each function below is meant to be a method of a LightningModule subclass
def test_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=False)
    return loader


# can also return multiple dataloaders (loader_a ... loader_n are placeholders)
def test_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
- Note:
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- Note:
In the case where you return multiple test dataloaders, test_step() will have an argument dataloader_idx that matches the order here.
- Return type:
DataLoader
route_distances.lstm.defaults module
Module with defaults used in more than one place
route_distances.lstm.features module
Module for calculating feature and utility vectors for the LSTM-based model
- route_distances.lstm.features.add_fingerprints(tree, radius=2, nbits=2048)
Add Morgan fingerprints to the input tree
- Parameters:
tree (Dict[str, Any]) – the input tree
radius (int) – the radius of the Morgan calculation
nbits (int) – the length of the bitvector
- Return type:
None
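Since the return type is None, the fingerprints are added to the tree in place. The sketch below shows that traversal under stated assumptions: AiZynthFinder-style nodes with "type", "smiles" and optional "children" keys, fingerprints attached only to molecule nodes, and a hypothetical "fingerprint" key. The Morgan calculation is delegated to a caller-supplied function; with RDKit this could be `lambda smi: AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=2048)`, matching the default radius and nbits.

```python
from typing import Any, Callable, Dict


def add_fingerprints_sketch(
    tree: Dict[str, Any], fingerprint_fn: Callable[[str], Any]
) -> None:
    """Hypothetical in-place fingerprint annotation of a reaction tree.

    Assumptions: molecule nodes have type "mol" (missing type is treated
    as a molecule), and the result is stored under "fingerprint".
    """
    if tree.get("type", "mol") == "mol":
        tree["fingerprint"] = fingerprint_fn(tree["smiles"])
    # Reaction nodes are not fingerprinted, but their precursors are
    for child in tree.get("children", []):
        add_fingerprints_sketch(child, fingerprint_fn)
```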
- route_distances.lstm.features.remove_reactions(tree)
Remove reaction nodes from the input tree
Does not overwrite the original tree.
- Parameters:
tree (Dict[str, Any])
- Return type:
Dict[str, Any]
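A minimal sketch of one way to drop reaction nodes without modifying the input, assuming the AiZynthFinder dictionary layout in which each molecule node lists reaction nodes under "children" and each reaction node lists its precursor molecules the same way. The function name is hypothetical and the library's own implementation may differ.

```python
import copy
from typing import Any, Dict, List


def remove_reactions_sketch(tree: Dict[str, Any]) -> Dict[str, Any]:
    """Return a molecule-only copy of an AiZynthFinder-style reaction tree."""
    # Copy every key except "children" so the input tree stays untouched
    new_node = {key: copy.deepcopy(value) for key, value in tree.items() if key != "children"}
    children: List[Dict[str, Any]] = []
    for reaction in tree.get("children", []):
        # Skip the reaction node itself and attach its precursors directly
        for precursor in reaction.get("children", []):
            children.append(remove_reactions_sketch(precursor))
    new_node["children"] = children
    return new_node
```

The result is a tree in which every parent/child edge connects a product molecule to one of its precursors.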
- route_distances.lstm.features.preprocess_reaction_tree(tree, nfeatures=2048)
Preprocess a reaction tree as produced by AiZynthFinder
- Parameters:
tree (Dict[str, Any]) – the input tree
nfeatures (int) – the number of features, i.e. fingerprint length
- Returns:
a tree that could be fed to the LSTM-based model
- Return type:
Dict[str, Any]
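The returned tree is meant for the LSTM-based model, and the collate_trees documentation in route_distances.lstm.utils refers to per-tree node orders, edge orders and adjacency lists in the style of the treelstm package. As a hedged illustration of what those orders mean, the sketch below computes them from an adjacency list of [parent, child] index pairs: leaves get order 0, a parent is evaluated one step after its deepest child, and each edge inherits its parent's order. The function name is hypothetical and the exact representation produced by preprocess_reaction_tree may differ.

```python
from typing import List, Tuple


def evaluation_orders(
    adjacency_list: List[List[int]], tree_size: int
) -> Tuple[List[int], List[int]]:
    """Compute node and edge evaluation orders for a Tree-LSTM.

    `adjacency_list` holds [parent_index, child_index] pairs and node
    indices run from 0 to tree_size - 1.
    """
    children: List[List[int]] = [[] for _ in range(tree_size)]
    for parent, child in adjacency_list:
        children[parent].append(child)

    node_order = [0] * tree_size

    def height(node: int) -> int:
        # Leaves stay at order 0; a parent waits for all of its children
        if children[node]:
            node_order[node] = 1 + max(height(child) for child in children[node])
        return node_order[node]

    # Roots are nodes that never appear as a child
    child_set = {child for _, child in adjacency_list}
    for node in range(tree_size):
        if node not in child_set:
            height(node)

    # An edge is processed when its parent node is evaluated
    edge_order = [node_order[parent] for parent, _ in adjacency_list]
    return node_order, edge_order
```

For a root 0 with children 1 and 2, where node 1 has child 3, the node orders are [2, 1, 0, 0] and the edge orders are [2, 2, 1].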
route_distances.lstm.inference module
Module for predicting the route distance matrix
- route_distances.lstm.inference.distances_calculator(model_path)
route_distances.lstm.models module
Module containing the LSTM-based model for calculating route distances
- class route_distances.lstm.models.RouteDistanceModel(fp_size=2048, lstm_size=1024, dropout_prob=0.4, learning_rate=0.001, weight_decay=0.001)
Bases: LightningModule
Model for computing the distances between two synthesis routes
- Parameters:
fp_size (int) – the length of the fingerprint vector
lstm_size (int) – the size of the LSTM cell
dropout_prob (float) – the dropout probability
learning_rate (float) – the initial learning rate of the optimizer
weight_decay (float) – weight decay factor of the optimizer
- forward(tree_data)
Calculate the pairwise distances between the input trees
- Parameters:
tree_data (Dict[str, Any]) – collated trees from the route_distances.lstm.utils.collate_trees function
- Returns:
the distances in condensed form
- Return type:
Tensor
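The distances are returned in condensed form. Assuming this follows the usual convention of `scipy.spatial.distance.squareform` (the upper triangle of the symmetric distance matrix read row by row, without the zero diagonal, so n trees give n(n-1)/2 values), the sketch below expands such a vector into a full matrix. The helper name is hypothetical; in practice `squareform` performs the same conversion directly.

```python
import math
from typing import List, Sequence


def condensed_to_square(condensed: Sequence[float]) -> List[List[float]]:
    """Expand a condensed distance vector into a symmetric square matrix."""
    # Solve n * (n - 1) / 2 == len(condensed) for n
    n = int(round((1 + math.sqrt(1 + 8 * len(condensed))) / 2))
    if n * (n - 1) // 2 != len(condensed):
        raise ValueError("length is not a triangular number")
    matrix = [[0.0] * n for _ in range(n)]
    k = 0
    # Entries are ordered (0, 1), (0, 2), ..., (0, n - 1), (1, 2), ...
    for i in range(n):
        for j in range(i + 1, n):
            matrix[i][j] = matrix[j][i] = float(condensed[k])
            k += 1
    return matrix
```

For three trees, `condensed_to_square([1.0, 2.0, 3.0])` gives `[[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]]`.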
- training_step(batch, _)
One step in the training loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
the loss tensor
- Return type:
Tensor
- validation_step(batch, _)
One step in the validation loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
the validation metrics
- Return type:
Dict[str, Any]
- validation_epoch_end(outputs)
Log the average validation metrics
- Parameters:
outputs (List[Dict[str, Any]])
- Return type:
None
- test_step(batch, _)
One step in the test loop
- Parameters:
batch (Dict[str, Any]) – collated pair data from the route_distances.lstm.utils.collate_batch function
_ – ignored
- Returns:
the test metrics
- Return type:
Dict[str, Any]
- test_epoch_end(outputs)
Log the average test metrics
- Parameters:
outputs (List[Dict[str, Any]])
- Return type:
None
- configure_optimizers()
Set up the Adam optimizer and the learning rate scheduler
- Return type:
Tuple[List[Adam], List[Dict[str, Any]]]
route_distances.lstm.optim module
Module containing an objective class for Optuna optimization
- class route_distances.lstm.optim.OptunaObjective(filename)
Bases: object
Representation of an objective function for Optuna
- Parameters:
filename (str) – the path to a pickle file with pre-processed trees
route_distances.lstm.utils module
Module for tree utilities
- route_distances.lstm.utils.accumulate_stats(stats)
Accumulate statistics from a list of statistics
- Parameters:
stats (List[Dict[str, float]])
- Return type:
Dict[str, float]
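Whether accumulate_stats sums or averages its input is not stated here. The sketch below assumes summation per key and derives the average from it, which is what the validation_epoch_end and test_epoch_end hooks are documented to log. Both function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List


def accumulate_stats_sketch(stats: List[Dict[str, float]]) -> Dict[str, float]:
    """Hypothetical accumulation: sum each statistic over the list."""
    totals: Dict[str, float] = defaultdict(float)
    for entry in stats:
        for key, value in entry.items():
            totals[key] += value
    return dict(totals)


def average_stats(stats: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each statistic, e.g. per-batch validation metrics over an epoch."""
    totals = accumulate_stats_sketch(stats)
    # Assumes every entry carries the same keys
    return {key: value / len(stats) for key, value in totals.items()}
```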
- route_distances.lstm.utils.add_node_index(node, n=0)
Add an index to the node and all its children
- Parameters:
node (Dict[str, Any])
n (int)
- Return type:
int
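A minimal recursive sketch of depth-first, parent-first numbering. It assumes the index is stored under an "index" key, children sit under "children", and the returned integer is the last index assigned so that the caller can continue numbering from it; all three are assumptions, and the function name is hypothetical.

```python
from typing import Any, Dict


def add_node_index_sketch(node: Dict[str, Any], n: int = 0) -> int:
    """Number `node` and its descendants in pre-order, starting at `n`."""
    node["index"] = n
    for child in node.get("children", []):
        # The next child starts right after the last index used so far
        n = add_node_index_sketch(child, n + 1)
    return n
```

For a root with children B and C, where B has one child D, the indices become root=0, B=1, D=2, C=3, and the function returns 3.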
- route_distances.lstm.utils.collate_batch(batch)
Collate a batch of tree data
Collate the first tree of all pairs together, and then collate the second tree of all pairs.
Convert all matrices to pytorch tensors.
- The output dictionary has the following keys:
tree1: the collated first tree for all pairs
tree2: the collated second tree for all pairs
ted: the TED for each pair of trees
- Parameters:
batch (List[Dict[str, Any]]) – the list of tree data
- Returns:
the collated batch
- Return type:
Dict[str, Any]
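The pairing logic can be sketched as follows, assuming each batch item is a dict with "tree1", "tree2" and "ted" keys. The tree-collation step is passed in as a function so the sketch stays independent of torch; in the library that role is played by collate_trees, and the TED values end up as a tensor rather than a list. The function name is hypothetical.

```python
from typing import Any, Callable, Dict, List


def collate_batch_sketch(
    batch: List[Dict[str, Any]],
    collate_trees_fn: Callable[[List[Dict[str, Any]]], Any],
) -> Dict[str, Any]:
    """Collate pairs: all first trees together, then all second trees."""
    return {
        "tree1": collate_trees_fn([item["tree1"] for item in batch]),
        "tree2": collate_trees_fn([item["tree2"] for item in batch]),
        # One TED value per pair, in batch order
        "ted": [item["ted"] for item in batch],
    }
```

Keeping the first and second trees in two separate collated structures means position i in "tree1", "tree2" and "ted" always refers to the same pair.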
- route_distances.lstm.utils.collate_trees(trees)
Collate a list of trees by stacking the feature vectors, the node orders and the edge orders. The adjacency list is adjusted with an offset.
This is a modified version of the function from the treelstm package that also converts all matrices to tensors
- The output dictionary has the following keys:
features: the stacked node features
node_order: the stacked node orders
edge_order: the stacked edge orders
adjacency_list: the stacked and adjusted adjacency list
tree_size: the number of nodes in each tree
- Parameters:
trees (List[Dict[str, Any]]) – the trees to collate
- Returns:
the collated tree data
- Return type:
Dict[str, Any]
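The offset adjustment is the only non-trivial part of the stacking: node indices in each tree's adjacency list are local to that tree, so they must be shifted by the number of nodes in all preceding trees. The sketch below uses plain lists instead of tensors (the documented function converts to tensors, for instance via concatenation in torch) and assumes each tree carries "features", "node_order", "edge_order" and "adjacency_list" keys. The function name is hypothetical.

```python
from typing import Any, Dict, List


def collate_trees_sketch(trees: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Concatenate per-tree data into one batch (plain lists, no tensors)."""
    tree_sizes = [len(tree["node_order"]) for tree in trees]
    features: List[Any] = []
    node_order: List[int] = []
    edge_order: List[int] = []
    adjacency_list: List[List[int]] = []
    offset = 0
    for tree, size in zip(trees, tree_sizes):
        features.extend(tree["features"])
        node_order.extend(tree["node_order"])
        edge_order.extend(tree["edge_order"])
        # Shift local indices so they point into the stacked node list
        adjacency_list.extend(
            [parent + offset, child + offset] for parent, child in tree["adjacency_list"]
        )
        offset += size
    return {
        "features": features,
        "node_order": node_order,
        "edge_order": edge_order,
        "adjacency_list": adjacency_list,
        "tree_size": tree_sizes,
    }
```

The "tree_size" entry makes it possible to split the stacked node outputs back into per-tree chunks after the forward pass.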
- route_distances.lstm.utils.gather_adjacency_list(node)
Create the adjacency list of a tree
- Parameters:
node (Dict[str, Any]) – the current node in the tree
- Returns:
the adjacency list
- Return type:
List[List[int]]
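A recursive sketch of building the adjacency list, assuming every node already carries an integer under an "index" key (for example from a prior numbering pass) and lists its children under "children". Both key names and the function name are assumptions.

```python
from typing import Any, Dict, List


def gather_adjacency_list_sketch(node: Dict[str, Any]) -> List[List[int]]:
    """Collect [parent_index, child_index] pairs by walking the tree."""
    adjacency_list: List[List[int]] = []
    for child in node.get("children", []):
        adjacency_list.append([node["index"], child["index"]])
        # Recurse so edges from deeper levels are included as well
        adjacency_list.extend(gather_adjacency_list_sketch(child))
    return adjacency_list
```

For a root 0 with children 1 and 3, where node 1 has child 2, the result is `[[0, 1], [1, 2], [0, 3]]`.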
- route_distances.lstm.utils.gather_node_attributes(node, key)
Collect node attributes by recursively traversing the tree
- Parameters:
node (Dict[str, Any]) – the current node in the tree
key (str) – the name of the attribute to extract
- Returns:
the list of attributes gathered
- Return type:
List[Any]
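A minimal sketch of the recursive gathering, assuming children are stored under "children". A pre-order traversal (parent before children) visits nodes in the same order as a depth-first, parent-first numbering, so under that assumption the i-th entry belongs to the node with index i, which is what lets the gathered features line up with an adjacency list. The function name is hypothetical.

```python
from typing import Any, Dict, List


def gather_node_attributes_sketch(node: Dict[str, Any], key: str) -> List[Any]:
    """Collect node[key] for every node in pre-order (parent before children)."""
    attributes = [node[key]]
    for child in node.get("children", []):
        attributes.extend(gather_node_attributes_sketch(child, key))
    return attributes
```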