Evaluating a Network with Cross-Validation

This page walks through a full example of using 10-fold cross-validation to check that PARROT is training accurate and generalizable networks. We will use the “seq_regress_dataset.tsv” dataset found in the /data folder, and I’m going to save the networks and other output files to a folder named “/output”.


First, let’s generate the 10 different split-files using parrot-cvsplit. Each of these split-files will specify a different 1/10th of the dataset to be the held-out test set. The remaining 9/10ths will be partitioned randomly into training and validation sets, 80:20.

parrot-cvsplit data/seq_regress_dataset.tsv output/cv_example_splits -k 10 -t 0.8

This should generate 10 files in /output named cv_example_splits_cv[0-9].txt
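To make the partitioning concrete, here is a minimal Python sketch of the scheme described above. It is not parrot-cvsplit's actual implementation: the function name cv_splits, the seed argument, and the rounding rule are assumptions made for illustration, and it returns index lists instead of writing split-files.

```python
import random

def cv_splits(n_samples, k=10, train_fraction=0.8, seed=0):
    """Return k (train, val, test) index lists for k-fold cross-validation.

    Illustrative only: NOT how parrot-cvsplit is implemented internally.
    """
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    # Deal the shuffled indices into k folds of near-equal size.
    folds = [indices[i::k] for i in range(k)]
    splits = []
    for i in range(k):
        test = folds[i]  # this fold is the held-out test set
        # Everything outside the held-out fold is re-shuffled, then cut 80:20.
        rest = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        rng.shuffle(rest)
        n_train = int(round(train_fraction * len(rest)))
        splits.append((rest[:n_train], rest[n_train:], test))
    return splits

# Hypothetical dataset of 100 samples, used only to show the set sizes.
splits = cv_splits(n_samples=100)
for fold, (train, val, test) in enumerate(splits):
    print(fold, len(train), len(val), len(test))
```

The point to notice is that every sample lands in exactly one test set across the 10 folds, while its role as training or validation data changes from fold to fold.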

10-fold CV training

Next, we want to train a separate PARROT network for each of these cross-validation folds using parrot-train. We will manually specify which samples belong in the training, validation and test sets by using the --split flag in conjunction with the split-files we just generated. For this example, we are going to use the hyperparameters -nl 2 -hs 15 -lr 0.001 -b 16 and train for 250 epochs (these are relatively arbitrary decisions). It’s important here to save the output prediction files under different names, so that we can go back and analyze all of them combined at the end of network training. Note that the commands from here on read the dataset from datasets/ and read and write everything else in cv_test/; substitute whichever folders hold your own dataset and split-files.

If you like, you can also use the --include-figs flag to assess how each of these networks performs individually. For this example, I did not include this flag because I will do the analysis at the end using some external code.

parrot-train datasets/seq_regress_dataset.tsv cv_test/network0.pt -d sequence -c 1 -nl 2 -hs 15 -lr 0.001 -b 16 -e 250 --split cv_test/cv_example_splits_cv0.txt

PARROT with user-specified parameters
Validation set loss per epoch:
Epoch 0     Loss 3.4118
Epoch 5     Loss 0.7524
Epoch 10    Loss 0.7530
Epoch 235   Loss 0.1220
Epoch 240   Loss 0.1229
Epoch 245   Loss 0.1248

Test Loss: 0.0335

Now repeat this command 9 more times, once for each of the remaining cross-validation folds (only folds 1, 2 and 9 are written out below). I used the --silent flag to prevent additional output to the terminal, but this is totally optional.

parrot-train datasets/seq_regress_dataset.tsv cv_test/network1.pt -d sequence -c 1 -nl 2 -hs 15 -lr 0.001 -b 16 -e 250 --split cv_test/cv_example_splits_cv1.txt --silent
parrot-train datasets/seq_regress_dataset.tsv cv_test/network2.pt -d sequence -c 1 -nl 2 -hs 15 -lr 0.001 -b 16 -e 250 --split cv_test/cv_example_splits_cv2.txt --silent
parrot-train datasets/seq_regress_dataset.tsv cv_test/network9.pt -d sequence -c 1 -nl 2 -hs 15 -lr 0.001 -b 16 -e 250 --split cv_test/cv_example_splits_cv9.txt --silent

All of this could also be accomplished using a wrapper shell script. In fact, for larger datasets this is recommended so that you don’t have to sit around and wait for the network to train each fold.
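As one possible way to automate the ten runs, the loop can also be written in Python with the standard subprocess module instead of a shell script. This is only a sketch: the names DATASET, OUT_DIR, train_command and run_all_folds are made up for illustration, the paths are copied from the commands above, and --silent is applied to every fold, including fold 0.

```python
import shlex
import subprocess

# Paths and hyperparameters copied from the example commands; adjust to taste.
DATASET = "datasets/seq_regress_dataset.tsv"
OUT_DIR = "cv_test"
HYPERPARAMS = ["-d", "sequence", "-c", "1", "-nl", "2", "-hs", "15",
               "-lr", "0.001", "-b", "16", "-e", "250"]

def train_command(fold):
    """Build the parrot-train argument list for one cross-validation fold."""
    return ["parrot-train", DATASET, f"{OUT_DIR}/network{fold}.pt", *HYPERPARAMS,
            "--split", f"{OUT_DIR}/cv_example_splits_cv{fold}.txt", "--silent"]

def run_all_folds(k=10, dry_run=True):
    """Train one network per fold, or just print the commands if dry_run."""
    for fold in range(k):
        cmd = train_command(fold)
        if dry_run:
            print(shlex.join(cmd))
        else:
            # check=True stops the loop if any fold fails to train.
            subprocess.run(cmd, check=True)

# Print the ten commands without running them.
run_all_folds(dry_run=True)
```

Calling run_all_folds(dry_run=False) would train the folds one after another; the dry run is the default so that the commands can be inspected first.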

Analyze CV predictions

After running all of this, you should have 10 files with test set predictions: “network[0-9]_predictions.tsv”. First, we will evaluate each set of predictions separately (using the square of Pearson’s R) and see how the networks perform on average and how much variance there is between the different networks. In published work that uses cross-validation, it’s typical to report this average performance (by some metric) across CV folds along with the variance between folds.
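The per-fold evaluation might look something like the sketch below, which uses only the Python standard library. It is not the external code used to produce the results on this page. It ASSUMES each predictions file is whitespace-delimited with the true value in the second-to-last column and the predicted value in the last column; check your own files and change true_col and pred_col if they differ. The function names are also made up for illustration.

```python
import statistics

def pearson_r2(true_vals, pred_vals):
    """Square of Pearson's correlation coefficient for two equal-length lists."""
    mean_t = statistics.fmean(true_vals)
    mean_p = statistics.fmean(pred_vals)
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(true_vals, pred_vals))
    var_t = sum((t - mean_t) ** 2 for t in true_vals)
    var_p = sum((p - mean_p) ** 2 for p in pred_vals)
    if var_t == 0 or var_p == 0:
        raise ValueError("correlation is undefined for constant values")
    return cov ** 2 / (var_t * var_p)

def load_predictions(path, true_col=-2, pred_col=-1):
    """Read (true, predicted) values from one predictions file.

    ASSUMPTION: whitespace-delimited rows, true value second to last,
    predicted value last. Adjust the column indices for your files.
    """
    true_vals, pred_vals = [], []
    with open(path) as handle:
        for line in handle:
            fields = line.split()
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            true_vals.append(float(fields[true_col]))
            pred_vals.append(float(fields[pred_col]))
    return true_vals, pred_vals

def summarize_folds(paths):
    """Return per-fold R^2 scores, their mean, and their sample std. dev."""
    scores = [pearson_r2(*load_predictions(p)) for p in paths]
    return scores, statistics.fmean(scores), statistics.stdev(scores)

# Example call (the directory is an assumption based on the commands above):
# paths = [f"cv_test/network{i}_predictions.tsv" for i in range(10)]
# scores, mean_r2, sd_r2 = summarize_folds(paths)
```

The mean and standard deviation returned here correspond to the "average performance" and "variance between folds" that are usually reported.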


Since all of these networks perform well, and there is low variance between predictions, we can also combine each of these prediction files and see how the cumulative predictions fare:


There are a couple of outlier sequences, but overall it looks like our networks did great!
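Pooling the folds and picking out the outliers can be sketched as follows. The row layout (sequence ID, true value, predicted value), the function names, and the toy numbers are all assumptions for illustration; none of the values below are real PARROT output.

```python
def pool_folds(per_fold_rows):
    """Flatten per-fold lists of (seq_id, true, pred) rows into one list."""
    return [row for fold in per_fold_rows for row in fold]

def pooled_r2(rows):
    """Square of Pearson's R over the pooled (true, pred) pairs."""
    n = len(rows)
    mean_t = sum(r[1] for r in rows) / n
    mean_p = sum(r[2] for r in rows) / n
    cov = sum((r[1] - mean_t) * (r[2] - mean_p) for r in rows)
    var_t = sum((r[1] - mean_t) ** 2 for r in rows)
    var_p = sum((r[2] - mean_p) ** 2 for r in rows)
    return cov ** 2 / (var_t * var_p)

def largest_errors(rows, n=5):
    """Return the n rows with the largest absolute prediction error."""
    return sorted(rows, key=lambda r: abs(r[1] - r[2]), reverse=True)[:n]

# Toy data standing in for two folds of predictions (made-up values).
fold_a = [("s1", 1.0, 1.1), ("s2", 2.0, 1.9)]
fold_b = [("s3", 3.0, 3.2), ("s4", 4.0, 1.0)]
pooled = pool_folds([fold_a, fold_b])

print(pooled_r2(pooled))
print(largest_errors(pooled, n=1))
```

Because each sequence appears in exactly one test set, the pooled list contains one held-out prediction per sequence, which is what makes the combined comparison meaningful.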

Train final network

Finally, now that we have a good estimate of how reliable our networks’ predictions are, we can create a new network using all of our data (and training for longer for good measure!). To do this, we simply run parrot-train again while specifying a test set size of ~0 using the --set-fractions flag (we will use 0.01, since using 0 throws an error). Importantly, the validation set CANNOT be zero, as this subset is critical for making sure we do not overfit our data.

parrot-train datasets/seq_regress_dataset.tsv cv_test/final_cvnetwork.pt -d sequence -c 1 -nl 2 -hs 15 -lr 0.001 -b 16 -e 350 --set-fractions 0.79 0.2 0.01
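To see what these fractions mean in practice, the short sketch below checks a train/validation/test triple and estimates the resulting set sizes. Treat it as an illustration only: the function names are made up, the requirement that the fractions sum to 1 is an assumption, the dataset size of 1000 is hypothetical, and PARROT's own rounding may differ.

```python
import math

def check_set_fractions(train, val, test):
    """Sanity-check a (train, val, test) triple before passing it on."""
    # ASSUMPTION: the three fractions are expected to sum to 1.
    if not math.isclose(train + val + test, 1.0, abs_tol=1e-9):
        raise ValueError("fractions should sum to 1")
    if val <= 0:
        raise ValueError("validation fraction must be greater than 0")
    if test <= 0:
        raise ValueError("test fraction must be greater than 0")
    return True

def approx_set_sizes(n_samples, train, val, test):
    """Rough set sizes; the test set takes whatever remains after rounding."""
    n_train = int(round(train * n_samples))
    n_val = int(round(val * n_samples))
    return n_train, n_val, n_samples - n_train - n_val

check_set_fractions(0.79, 0.2, 0.01)
# Hypothetical dataset of 1000 sequences.
print(approx_set_sizes(1000, 0.79, 0.2, 0.01))
```

With a 0.01 test fraction, almost all of the data goes into training and validation, which is the goal for the final network.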

Now if we want, we can use final_cvnetwork.pt to predict unlabeled sequences with parrot-predict.