"""
Core training module of PARROT

.............................................................................
idptools-parrot was developed by the Holehouse lab
     Original release ---- 2020

Question/comments/concerns? Raise an issue on github:
https://github.com/idptools/parrot

Licensed under the MIT license. 
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

from parrot import brnn_plot
from parrot import encode_sequence


def train(network, train_loader, val_loader, datatype, problem_type, weights_file,
          stop_condition, device, learn_rate, n_epochs, verbose=False, silent=False):
    """Train a BRNN and save the best performing network weights

    Train the network on a training set, and every epoch evaluate its
    performance on a validation set. Save the network weights that achieve the
    best performance on the validation set.

    User must specify the machine learning task (`problem_type`) and the format
    of the data (`datatype`). Additionally, this function requires the learning
    rate hyperparameter and the number of epochs of training. The other
    hyperparameters, number of hidden layers and hidden vector size, are
    implicitly included in the provided network.

    The user may specify whether to train the network for a set number of
    epochs or until an automatic stopping condition is reached with the
    argument `stop_condition`. Depending on the stopping condition used, the
    `n_epochs` argument will have a different role.

    Parameters
    ----------
    network : PyTorch network object
        A BRNN network with the desired architecture
    train_loader : PyTorch DataLoader object
        A DataLoader containing the sequences and targets of the training set
    val_loader : PyTorch DataLoader object
        A DataLoader containing the sequences and targets of the validation set
    datatype : str
        The format of values in the dataset. Should be 'sequence' for datasets
        with a single value (or class label) per sequence, or 'residues' for
        datasets with values (or class labels) for every residue in a sequence.
    problem_type : str
        The machine learning task--should be either 'regression' or
        'classification'.
    weights_file : str
        A path to the location where the best-performing network weights will
        be saved
    stop_condition : str
        Determines when to conclude network training. If 'iter', then the
        network will train for `n_epochs` epochs, then stop. If 'auto', then
        the network will train for at least `n_epochs` epochs, then begin
        assessing whether performance has sufficiently stagnated. If the
        performance plateaus for `n_epochs` consecutive epochs, then training
        will stop.
    device : str
        Location of where training will take place--should be either 'cpu' or
        'cuda' (GPU). If available, training on GPU is typically much faster.
    learn_rate : float
        Initial learning rate of network training. The training process is
        controlled by the Adam optimization algorithm, so this learning rate
        will tend to decrease as training progresses.
    n_epochs : int
        Number of epochs to train for, or required to have stagnated
        performance for, depending on `stop_condition`.
    verbose : bool, optional
        If True, causes training updates to be written every epoch, rather
        than every 5 epochs.
    silent : bool, optional
        If True, causes no training updates to be written to standard out.

    Returns
    -------
    list
        A list of the average training set losses achieved at each epoch
    list
        A list of the average validation set losses achieved at each epoch
    """

    # Set optimizer
    optimizer = torch.optim.Adam(network.parameters(), lr=learn_rate)

    # Set loss criteria
    if problem_type == 'regression':
        if datatype == 'residues':
            criterion = nn.MSELoss(reduction='sum')
        elif datatype == 'sequence':
            criterion = nn.L1Loss(reduction='sum')
    elif problem_type == 'classification':
        criterion = nn.CrossEntropyLoss(reduction='sum')

    network = network.float()
    total_step = len(train_loader)
    min_val_loss = np.inf
    avg_train_losses = []
    avg_val_losses = []

    if stop_condition == 'auto':
        min_epochs = n_epochs
        # Set to some arbitrarily large number of iterations -- will stop automatically
        n_epochs = 20000000
        last_decrease = 0

    # Train the model - evaluate performance on val set every epoch
    end_training = False
    for epoch in range(n_epochs):  # Main loop

        # Initialize training and testing loss for epoch
        train_loss = 0
        val_loss = 0

        # Iterate over batches
        for i, (names, vectors, targets) in enumerate(train_loader):
            vectors = vectors.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = network(vectors.float())
            if problem_type == 'regression':
                loss = criterion(outputs, targets.float())
            else:
                if datatype == 'residues':
                    outputs = outputs.permute(0, 2, 1)
                loss = criterion(outputs, targets.long())

            train_loss += loss.data.item()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        for names, vectors, targets in val_loader:
            vectors = vectors.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = network(vectors.float())
            if problem_type == 'regression':
                loss = criterion(outputs, targets.float())
            else:
                if datatype == 'residues':
                    outputs = outputs.permute(0, 2, 1)
                loss = criterion(outputs, targets.long())

            # Increment val loss
            val_loss += loss.data.item()

        # Avg loss:
        train_loss /= len(train_loader.dataset)
        val_loss /= len(val_loader.dataset)

        signif_decrease = True
        if stop_condition == 'auto' and epoch > min_epochs - 1:
            # Check to see if loss has stopped decreasing
            last_epochs_loss = avg_val_losses[-min_epochs:]
            for loss in last_epochs_loss:
                if val_loss >= loss * 0.995:
                    signif_decrease = False

            # If network performance has plateaued over the last range of
            # epochs, end training
            if not signif_decrease and epoch - last_decrease > min_epochs:
                end_training = True

        # Only save updated weights to memory if they improve val set performance
        if val_loss < min_val_loss:
            min_val_loss = val_loss  # Reset min_val_loss
            last_decrease = epoch
            torch.save(network.state_dict(), weights_file)  # Save model

        # Append losses to lists
        avg_train_losses.append(train_loss)
        avg_val_losses.append(val_loss)

        if verbose:
            print('Epoch %d\tLoss %.4f' % (epoch, val_loss))
        elif epoch % 5 == 0 and silent is False:
            print('Epoch %d\tLoss %.4f' % (epoch, val_loss))

        # This is placed here to ensure that the best network, even if the
        # performance improvement is marginal, is saved.
        if end_training:
            break

    # Return loss per epoch so that they can be plotted
    return avg_train_losses, avg_val_losses
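
# ---------------------------------------------------------------------------
# Example: driving train() end to end. This is a minimal sketch, not part of
# the original module -- the stand-in network and toy dataset below are
# hypothetical, chosen only to reproduce the (names, vectors, targets) batch
# format that train() iterates over.

class _ToyMtO(nn.Module):
    """Hypothetical stand-in for a PARROT many-to-one BRNN: one-hot encoded
    sequences in, a single regression value per sequence out."""

    def __init__(self, input_size=20, hidden_size=8):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                            bidirectional=True)
        self.fc = nn.Linear(2 * hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])  # predict from the final timestep


class _ToyDataset(Dataset):
    """Hypothetical dataset yielding (name, one-hot vector, target) samples,
    mirroring the tuples that PARROT's DataLoaders produce."""

    def __init__(self, n_samples=16, seq_len=10):
        self.samples = [(f'seq{i}',
                         torch.eye(20)[torch.randint(0, 20, (seq_len,))],
                         torch.randn(1))
                        for i in range(n_samples)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def _example_train(weights_file='demo_weights.pt'):
    # Sequence-level regression on CPU for a fixed 5 epochs; the best
    # weights land in weights_file as train() runs.
    train_loader = DataLoader(_ToyDataset(), batch_size=4)
    val_loader = DataLoader(_ToyDataset(), batch_size=4)

    net = _ToyMtO()
    return train(net, train_loader, val_loader, datatype='sequence',
                 problem_type='regression', weights_file=weights_file,
                 stop_condition='iter', device='cpu', learn_rate=1e-3,
                 n_epochs=5, silent=True)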
def test_labeled_data(network, test_loader, datatype, problem_type, weights_file,
                      num_classes, probabilistic_classification, include_figs,
                      device, output_file_prefix=''):
    """Test a trained BRNN on labeled sequences

    Using the saved weights of a trained network, run a set of sequences
    through the network and evaluate the performance. Return the average loss
    per sequence and plot the results. Testing a network on previously-unseen
    data provides a useful estimate of how generalizable the network's
    performance is.

    Parameters
    ----------
    network : PyTorch network object
        A BRNN network with the desired architecture
    test_loader : PyTorch DataLoader object
        A DataLoader containing the sequences and targets of the test set
    datatype : str
        The format of values in the dataset. Should be 'sequence' for datasets
        with a single value (or class label) per sequence, or 'residues' for
        datasets with values (or class labels) for every residue in a sequence.
    problem_type : str
        The machine learning task--should be either 'regression' or
        'classification'.
    weights_file : str
        A path to the location of the best-performing network weights
    num_classes : int
        Number of data classes. If regression task, put 1.
    probabilistic_classification : bool
        Whether output should be binary labels, or "weights" of each label
        type. This field is only implemented for binary, sequence
        classification tasks.
    include_figs : bool
        Whether or not matplotlib figures should be generated.
    device : str
        Location of where testing will take place--should be either 'cpu' or
        'cuda' (GPU). If available, training on GPU is typically much faster.
    output_file_prefix : str
        Path and filename prefix to which the test set predictions and plots
        will be saved.

    Returns
    -------
    float
        The average loss across the entire test set
    list of lists
        Details of the output predictions for each of the sequences in the
        test set. Each inner list represents a sample in the test set, with
        the format: [sequence_vector, true_value, predicted_value, sequence_ID]
    """

    # Load network weights
    network.load_state_dict(torch.load(weights_file))

    # Get output directory for images
    network_filename = weights_file.split('/')[-1]
    output_dir = weights_file[:-len(network_filename)]

    # Set loss criteria
    if problem_type == 'regression':
        criterion = nn.MSELoss()
    elif problem_type == 'classification':
        criterion = nn.CrossEntropyLoss()

    test_loss = 0
    all_targets = []
    all_outputs = []
    predictions = []
    for names, vectors, targets in test_loader:  # batch size of 1
        all_targets.append(targets)

        vectors = vectors.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = network(vectors.float())
        if problem_type == 'regression':
            loss = criterion(outputs, targets.float())
        else:
            if datatype == 'residues':
                outputs = outputs.permute(0, 2, 1)
            loss = criterion(outputs, targets.long())

        test_loss += loss.data.item()  # Increment test loss
        all_outputs.append(outputs.detach())

        # Add to list as: [seq_vector, true value, predicted value, name]
        predictions.append([vectors[0].cpu().numpy(), targets.cpu().numpy()[0],
                            outputs.cpu().detach().numpy(), names[0]])

    # Plot 'accuracy' depending on the problem type and datatype
    if problem_type == 'regression':
        if datatype == 'residues':
            if include_figs:
                brnn_plot.residue_regression_scatterplot(
                    all_targets, all_outputs, output_file_prefix=output_file_prefix)

            # Format predictions
            for i in range(len(predictions)):
                predictions[i][2] = predictions[i][2].flatten()
                predictions[i][1] = predictions[i][1].flatten()

        elif datatype == 'sequence':
            if include_figs:
                brnn_plot.sequence_regression_scatterplot(
                    all_targets, all_outputs, output_file_prefix=output_file_prefix)

            # Format predictions
            for i in range(len(predictions)):
                predictions[i][2] = predictions[i][2][0][0]
                predictions[i][1] = predictions[i][1][0]

    elif problem_type == 'classification':
        if datatype == 'residues':
            if include_figs:
                brnn_plot.res_confusion_matrix(
                    all_targets, all_outputs, num_classes,
                    output_file_prefix=output_file_prefix)

            # Format predictions and assign class predictions by taking the
            # class with the highest score at each residue position
            for i in range(len(predictions)):
                pred_values = np.argmax(predictions[i][2], axis=1)[0]
                predictions[i][2] = np.array(pred_values, dtype=int)

        elif datatype == 'sequence':
            if probabilistic_classification:
                # Probabilistic assignment of class predictions, e.g. every
                # sequence is assigned probabilities corresponding to each
                # possible class
                pred_probabilities = []
                for i in range(len(predictions)):
                    softmax = np.exp(predictions[i][2][0])
                    probs = softmax / np.sum(softmax)
                    predictions[i][2] = probs
                    pred_probabilities.append(probs)

                # Plot ROC and PR curves
                if include_figs:
                    brnn_plot.plot_roc_curve(
                        all_targets, pred_probabilities, num_classes,
                        output_file_prefix=output_file_prefix)
                    brnn_plot.plot_precision_recall_curve(
                        all_targets, pred_probabilities, num_classes,
                        output_file_prefix=output_file_prefix)
            else:
                # Absolute assignment of class predictions, e.g. every
                # sequence receives an integer class label
                for i in range(len(predictions)):
                    pred_value = np.argmax(predictions[i][2])
                    predictions[i][2] = int(pred_value)

                # Plot confusion matrix (if not in probabilistic classification mode)
                if include_figs:
                    brnn_plot.confusion_matrix(
                        all_targets, all_outputs, num_classes,
                        output_file_prefix=output_file_prefix)

    return test_loss / len(test_loader.dataset), predictions
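
# ---------------------------------------------------------------------------
# Example: evaluating saved weights with test_labeled_data(). A sketch under
# the same toy assumptions as _example_train above; it expects the weights
# file that _example_train wrote. Note the test DataLoader uses batch_size=1,
# since the loop above indexes vectors[0] and names[0].

def _example_test(weights_file='demo_weights.pt'):
    test_loader = DataLoader(_ToyDataset(), batch_size=1)
    net = _ToyMtO()

    # include_figs=False skips the matplotlib output; predictions come back
    # as [sequence_vector, true_value, predicted_value, sequence_ID] lists.
    avg_loss, predictions = test_labeled_data(
        net, test_loader, datatype='sequence', problem_type='regression',
        weights_file=weights_file, num_classes=1,
        probabilistic_classification=False, include_figs=False, device='cpu')
    return avg_loss, predictions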
def test_unlabeled_data(network, sequences, device, encoding_scheme='onehot',
                        encoder=None, print_frequency=None):
    """Test a trained BRNN on unlabeled sequences

    Use a trained network to make predictions on previously-unseen data.

    ** Note: Unlike the previous functions, `network` here must have
    pre-loaded weights. **

    Parameters
    ----------
    network : PyTorch network object
        A BRNN network with the desired architecture and pre-loaded weights
    sequences : list
        A list of amino acid sequences to test using the network
    device : str
        Location of where testing will take place--should be either 'cpu' or
        'cuda' (GPU). If available, training on GPU is typically much faster.
    encoding_scheme : str, optional
        How amino acid sequences are to be encoded as numeric vectors.
        Currently, 'onehot', 'biophysics' and 'user' are the implemented
        options.
    encoder : UserEncoder object, optional
        If encoding_scheme is 'user', encoder should be a UserEncoder object
        that can convert amino acid sequences to numeric vectors. If
        encoding_scheme is not 'user', use None.
    print_frequency : int, optional
        If provided, defines at what sequence interval an update is printed.
        Default = None.

    Returns
    -------
    dict
        A dictionary containing predictions mapped to sequences
    """

    pred_dict = {}
    local_count = -1
    total_count = len(sequences)
    for seq in sequences:
        local_count = local_count + 1
        if print_frequency is not None:
            if local_count % print_frequency == 0:
                print(f'On {local_count} of {total_count}')

        if encoding_scheme == 'onehot':
            seq_vector = encode_sequence.one_hot(seq)
        elif encoding_scheme == 'biophysics':
            seq_vector = encode_sequence.biophysics(seq)
        elif encoding_scheme == 'user':
            seq_vector = encoder.encode(seq)

        seq_vector = seq_vector.view(1, len(seq_vector), -1)

        # Forward pass
        outputs = network(seq_vector.float()).detach().numpy()
        pred_dict[seq] = outputs

    return pred_dict
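
# ---------------------------------------------------------------------------
# Example: predictions on raw sequences with test_unlabeled_data(). A sketch
# reusing the hypothetical _ToyMtO above; unlike the earlier functions, the
# network must already hold trained weights before this call. With the
# default 'onehot' scheme, each sequence is encoded as an (L x 20) vector
# before the forward pass, matching _ToyMtO's assumed input size of 20.

def _example_predict(weights_file='demo_weights.pt'):
    net = _ToyMtO()
    net.load_state_dict(torch.load(weights_file))  # pre-load trained weights
    net.eval()

    preds = test_unlabeled_data(net, ['MKVLAT', 'GSGSGSGS'], device='cpu',
                                encoding_scheme='onehot')
    for seq, value in preds.items():
        print(seq, value.item())  # value is a (1, 1) numpy array
    return preds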